-rw-r--r--  .buildkite/postgres-config.yaml  2
-rw-r--r--  .buildkite/sqlite-config.yaml  2
-rw-r--r--  CHANGES.md  78
-rw-r--r--  changelog.d/5815.feature  1
-rw-r--r--  changelog.d/5858.feature  1
-rw-r--r--  changelog.d/5925.feature  1
-rw-r--r--  changelog.d/5925.removal  1
-rw-r--r--  changelog.d/6119.feature  1
-rw-r--r--  changelog.d/6176.feature  1
-rw-r--r--  changelog.d/6237.bugfix  1
-rw-r--r--  changelog.d/6241.bugfix  1
-rw-r--r--  changelog.d/6266.misc  1
-rw-r--r--  changelog.d/6322.misc  1
-rw-r--r--  changelog.d/6329.bugfix  1
-rw-r--r--  changelog.d/6332.bugfix  1
-rw-r--r--  changelog.d/6333.bugfix  1
-rw-r--r--  changelog.d/6343.misc  1
-rw-r--r--  changelog.d/6354.feature  1
-rw-r--r--  changelog.d/6362.misc  1
-rw-r--r--  changelog.d/6369.doc  1
-rw-r--r--  changelog.d/6379.misc  1
-rw-r--r--  changelog.d/6388.doc  1
-rw-r--r--  changelog.d/6390.doc  1
-rw-r--r--  changelog.d/6392.misc  1
-rw-r--r--  changelog.d/6406.bugfix  1
-rw-r--r--  changelog.d/6408.bugfix  1
-rw-r--r--  changelog.d/6409.feature  1
-rw-r--r--  changelog.d/6411.feature  1
-rw-r--r--  changelog.d/6420.bugfix  1
-rw-r--r--  changelog.d/6421.bugfix  1
-rw-r--r--  changelog.d/6423.misc  1
-rw-r--r--  changelog.d/6426.bugfix  1
-rw-r--r--  changelog.d/6429.misc  1
-rw-r--r--  changelog.d/6434.feature  1
-rw-r--r--  changelog.d/6436.bugfix  1
-rw-r--r--  changelog.d/6443.doc  1
-rw-r--r--  changelog.d/6449.bugfix  1
-rw-r--r--  changelog.d/6451.bugfix  1
-rw-r--r--  changelog.d/6454.misc  1
-rw-r--r--  changelog.d/6458.doc  1
-rw-r--r--  changelog.d/6461.doc  1
-rw-r--r--  changelog.d/6462.bugfix  1
-rw-r--r--  changelog.d/6464.misc  1
-rw-r--r--  changelog.d/6468.misc  1
-rw-r--r--  changelog.d/6470.bugfix  1
-rw-r--r--  changelog.d/6472.bugfix  1
-rw-r--r--  changelog.d/6480.misc  1
-rw-r--r--  changelog.d/6482.misc  1
-rw-r--r--  changelog.d/6483.misc  1
-rw-r--r--  changelog.d/6497.bugfix  1
-rw-r--r--  changelog.d/6499.bugfix  1
-rw-r--r--  changelog.d/6501.misc  1
-rw-r--r--  changelog.d/6502.removal  1
-rw-r--r--  changelog.d/6503.misc  1
-rw-r--r--  changelog.d/6505.misc  1
-rw-r--r--  changelog.d/6506.misc  1
-rw-r--r--  changelog.d/6510.misc  1
-rw-r--r--  changelog.d/6512.misc  1
-rw-r--r--  changelog.d/6514.bugfix  1
-rw-r--r--  contrib/systemd/matrix-synapse.service  2
-rw-r--r--  docs/saml_mapping_providers.md  77
-rw-r--r--  docs/sample_config.yaml  61
-rw-r--r--  mypy.ini  2
-rwxr-xr-x  scripts-dev/update_database  17
-rwxr-xr-x  scripts/synapse_port_db  119
-rw-r--r--  synapse/__init__.py  2
-rw-r--r--  synapse/app/_base.py  2
-rw-r--r--  synapse/app/federation_sender.py  5
-rw-r--r--  synapse/app/homeserver.py  43
-rw-r--r--  synapse/app/user_dir.py  7
-rw-r--r--  synapse/config/saml2_config.py  186
-rw-r--r--  synapse/event_auth.py  8
-rw-r--r--  synapse/federation/federation_client.py  164
-rw-r--r--  synapse/federation/transport/client.py  24
-rw-r--r--  synapse/handlers/e2e_keys.py  38
-rw-r--r--  synapse/handlers/events.py  30
-rw-r--r--  synapse/handlers/federation.py  101
-rw-r--r--  synapse/handlers/initial_sync.py  19
-rw-r--r--  synapse/handlers/message.py  4
-rw-r--r--  synapse/handlers/saml_handler.py  198
-rw-r--r--  synapse/handlers/sync.py  251
-rw-r--r--  synapse/handlers/typing.py  2
-rw-r--r--  synapse/logging/context.py  11
-rw-r--r--  synapse/module_api/__init__.py  2
-rw-r--r--  synapse/notifier.py  29
-rw-r--r--  synapse/replication/slave/storage/_base.py  5
-rw-r--r--  synapse/replication/slave/storage/account_data.py  5
-rw-r--r--  synapse/replication/slave/storage/client_ips.py  5
-rw-r--r--  synapse/replication/slave/storage/deviceinbox.py  5
-rw-r--r--  synapse/replication/slave/storage/devices.py  5
-rw-r--r--  synapse/replication/slave/storage/events.py  5
-rw-r--r--  synapse/replication/slave/storage/filtering.py  5
-rw-r--r--  synapse/replication/slave/storage/groups.py  5
-rw-r--r--  synapse/replication/slave/storage/presence.py  5
-rw-r--r--  synapse/replication/slave/storage/push_rule.py  5
-rw-r--r--  synapse/replication/slave/storage/pushers.py  5
-rw-r--r--  synapse/replication/slave/storage/receipts.py  5
-rw-r--r--  synapse/replication/slave/storage/room.py  5
-rw-r--r--  synapse/rest/client/v1/profile.py  11
-rw-r--r--  synapse/rest/media/v1/preview_url_resource.py  2
-rw-r--r--  synapse/server.py  3
-rw-r--r--  synapse/storage/__init__.py  20
-rw-r--r--  synapse/storage/_base.py  1478
-rw-r--r--  synapse/storage/background_updates.py  31
-rw-r--r--  synapse/storage/data_stores/__init__.py  16
-rw-r--r--  synapse/storage/data_stores/main/__init__.py  57
-rw-r--r--  synapse/storage/data_stores/main/account_data.py  37
-rw-r--r--  synapse/storage/data_stores/main/appservice.py  29
-rw-r--r--  synapse/storage/data_stores/main/cache.py  6
-rw-r--r--  synapse/storage/data_stores/main/client_ips.py  75
-rw-r--r--  synapse/storage/data_stores/main/deviceinbox.py  38
-rw-r--r--  synapse/storage/data_stores/main/devices.py  94
-rw-r--r--  synapse/storage/data_stores/main/directory.py  20
-rw-r--r--  synapse/storage/data_stores/main/e2e_room_keys.py  28
-rw-r--r--  synapse/storage/data_stores/main/end_to_end_keys.py  44
-rw-r--r--  synapse/storage/data_stores/main/event_federation.py  51
-rw-r--r--  synapse/storage/data_stores/main/event_push_actions.py  55
-rw-r--r--  synapse/storage/data_stores/main/events.py  120
-rw-r--r--  synapse/storage/data_stores/main/events_bg_updates.py  86
-rw-r--r--  synapse/storage/data_stores/main/events_worker.py  23
-rw-r--r--  synapse/storage/data_stores/main/filtering.py  4
-rw-r--r--  synapse/storage/data_stores/main/group_server.py  164
-rw-r--r--  synapse/storage/data_stores/main/keys.py  12
-rw-r--r--  synapse/storage/data_stores/main/media_repository.py  61
-rw-r--r--  synapse/storage/data_stores/main/monthly_active_users.py  19
-rw-r--r--  synapse/storage/data_stores/main/openid.py  6
-rw-r--r--  synapse/storage/data_stores/main/presence.py  12
-rw-r--r--  synapse/storage/data_stores/main/profile.py  28
-rw-r--r--  synapse/storage/data_stores/main/push_rule.py  41
-rw-r--r--  synapse/storage/data_stores/main/pusher.py  34
-rw-r--r--  synapse/storage/data_stores/main/receipts.py  45
-rw-r--r--  synapse/storage/data_stores/main/registration.py  198
-rw-r--r--  synapse/storage/data_stores/main/rejections.py  4
-rw-r--r--  synapse/storage/data_stores/main/relations.py  12
-rw-r--r--  synapse/storage/data_stores/main/room.py  92
-rw-r--r--  synapse/storage/data_stores/main/roommember.py  84
-rw-r--r--  synapse/storage/data_stores/main/schema/delta/56/redaction_censor.sql  1
-rw-r--r--  synapse/storage/data_stores/main/schema/delta/56/redaction_censor2.sql  4
-rw-r--r--  synapse/storage/data_stores/main/schema/delta/56/redaction_censor4.sql  16
-rw-r--r--  synapse/storage/data_stores/main/search.py  70
-rw-r--r--  synapse/storage/data_stores/main/signatures.py  4
-rw-r--r--  synapse/storage/data_stores/main/state.py  86
-rw-r--r--  synapse/storage/data_stores/main/state_deltas.py  8
-rw-r--r--  synapse/storage/data_stores/main/stats.py  73
-rw-r--r--  synapse/storage/data_stores/main/stream.py  35
-rw-r--r--  synapse/storage/data_stores/main/tags.py  18
-rw-r--r--  synapse/storage/data_stores/main/transactions.py  27
-rw-r--r--  synapse/storage/data_stores/main/user_directory.py  122
-rw-r--r--  synapse/storage/data_stores/main/user_erasure_store.py  6
-rw-r--r--  synapse/storage/database.py  1490
-rw-r--r--  synapse/util/caches/snapshot_cache.py  94
-rw-r--r--  synapse/util/metrics.py  83
-rw-r--r--  synmark/__init__.py  6
-rw-r--r--  sytest-blacklist  3
-rw-r--r--  tests/handlers/test_e2e_keys.py  8
-rw-r--r--  tests/handlers/test_stats.py  80
-rw-r--r--  tests/handlers/test_sync.py  33
-rw-r--r--  tests/handlers/test_typing.py  24
-rw-r--r--  tests/handlers/test_user_directory.py  30
-rw-r--r--  tests/replication/slave/storage/_base.py  5
-rw-r--r--  tests/rest/admin/test_admin.py  2
-rw-r--r--  tests/rest/client/v1/test_typing.py  4
-rw-r--r--  tests/storage/test__base.py  16
-rw-r--r--  tests/storage/test_appservice.py  17
-rw-r--r--  tests/storage/test_background_update.py  28
-rw-r--r--  tests/storage/test_base.py  21
-rw-r--r--  tests/storage/test_cleanup_extrems.py  18
-rw-r--r--  tests/storage/test_client_ips.py  87
-rw-r--r--  tests/storage/test_event_federation.py  8
-rw-r--r--  tests/storage/test_event_push_actions.py  12
-rw-r--r--  tests/storage/test_monthly_active_users.py  6
-rw-r--r--  tests/storage/test_profile.py  3
-rw-r--r--  tests/storage/test_redaction.py  4
-rw-r--r--  tests/storage/test_roommember.py  20
-rw-r--r--  tests/storage/test_user_directory.py  4
-rw-r--r--  tests/test_federation.py  16
-rw-r--r--  tests/unittest.py  19
-rw-r--r--  tests/util/test_logcontext.py  24
-rw-r--r--  tests/util/test_snapshot_cache.py  63
179 files changed, 3891 insertions, 3554 deletions
diff --git a/.buildkite/postgres-config.yaml b/.buildkite/postgres-config.yaml
index a35fec394d..2acbe66f4c 100644
--- a/.buildkite/postgres-config.yaml
+++ b/.buildkite/postgres-config.yaml
@@ -1,7 +1,7 @@
 # Configuration file used for testing the 'synapse_port_db' script.
 # Tells the script to connect to the postgresql database that will be available in the
 # CI's Docker setup at the point where this file is considered.
-server_name: "test"
+server_name: "localhost:8800"
 
 signing_key_path: "/src/.buildkite/test.signing.key"
 
diff --git a/.buildkite/sqlite-config.yaml b/.buildkite/sqlite-config.yaml
index 635b921764..6d9bf80d84 100644
--- a/.buildkite/sqlite-config.yaml
+++ b/.buildkite/sqlite-config.yaml
@@ -1,7 +1,7 @@
 # Configuration file used for testing the 'synapse_port_db' script.
 # Tells the 'update_database' script to connect to the test SQLite database to upgrade its
 # schema and run background updates on it.
-server_name: "test"
+server_name: "localhost:8800"
 
 signing_key_path: "/src/.buildkite/test.signing.key"
 
diff --git a/CHANGES.md b/CHANGES.md
index a9afd36d2c..0ef9794aac 100644
--- a/CHANGES.md
+++ b/CHANGES.md
@@ -1,3 +1,81 @@
+Synapse 1.7.0rc1 (2019-12-09)
+=============================
+
+Features
+--------
+
+- Implement per-room message retention policies. ([\#5815](https://github.com/matrix-org/synapse/issues/5815))
+- Add etag and count fields to key backup endpoints to help clients guess if there are new keys. ([\#5858](https://github.com/matrix-org/synapse/issues/5858))
+- Add admin/v2/users endpoint with pagination. Contributed by Awesome Technologies Innovationslabor GmbH. ([\#5925](https://github.com/matrix-org/synapse/issues/5925))
+- Require User-Interactive Authentication for `/account/3pid/add`, meaning the user's password will be required to add a third-party ID to their account. ([\#6119](https://github.com/matrix-org/synapse/issues/6119))
+- Implement the `/_matrix/federation/unstable/net.atleastfornow/state/<context>` API as drafted in MSC2314. ([\#6176](https://github.com/matrix-org/synapse/issues/6176))
+- Configure privacy preserving settings by default for the room directory. ([\#6354](https://github.com/matrix-org/synapse/issues/6354))
+- Add ephemeral messages support by partially implementing [MSC2228](https://github.com/matrix-org/matrix-doc/pull/2228). ([\#6409](https://github.com/matrix-org/synapse/issues/6409))
+- Add support for MSC 2367, which allows specifying a reason on all membership events. ([\#6434](https://github.com/matrix-org/synapse/issues/6434))
+
+
+Bugfixes
+--------
+
+- Transfer non-standard power levels on room upgrade. ([\#6237](https://github.com/matrix-org/synapse/issues/6237))
+- Fix error from the Pillow library when uploading RGBA images. ([\#6241](https://github.com/matrix-org/synapse/issues/6241))
+- Correctly apply the event filter to the `state`, `events_before` and `events_after` fields in the response to `/context` requests. ([\#6329](https://github.com/matrix-org/synapse/issues/6329))
+- Fix caching devices for remote users when using workers, so that we don't attempt to refetch (and potentially fail) each time a user requests devices. ([\#6332](https://github.com/matrix-org/synapse/issues/6332))
+- Prevent account data syncs getting lost across TCP replication. ([\#6333](https://github.com/matrix-org/synapse/issues/6333))
+- Fix bug: TypeError in `register_user()` while using LDAP auth module. ([\#6406](https://github.com/matrix-org/synapse/issues/6406))
+- Fix an intermittent exception when handling read-receipts. ([\#6408](https://github.com/matrix-org/synapse/issues/6408))
+- Fix broken guest registration when there are existing blocks of numeric user IDs. ([\#6420](https://github.com/matrix-org/synapse/issues/6420))
+- Fix startup error when http proxy is defined. ([\#6421](https://github.com/matrix-org/synapse/issues/6421))
+- Clean up local threepids from user on account deactivation. ([\#6426](https://github.com/matrix-org/synapse/issues/6426))
+- Fix a bug where a room could become unusable with a low retention policy and low activity. ([\#6436](https://github.com/matrix-org/synapse/issues/6436))
+- Fix error when using synapse_port_db on a vanilla synapse db. ([\#6449](https://github.com/matrix-org/synapse/issues/6449))
+- Fix uploading multiple cross signing signatures for the same user. ([\#6451](https://github.com/matrix-org/synapse/issues/6451))
+- Fix bug which led to exceptions being thrown in a loop when a cross-signed device is deleted. ([\#6462](https://github.com/matrix-org/synapse/issues/6462))
+- Fix `synapse_port_db` not exiting with a 0 code if something went wrong during the port process. ([\#6470](https://github.com/matrix-org/synapse/issues/6470))
+- Improve sanity-checking when receiving events over federation. ([\#6472](https://github.com/matrix-org/synapse/issues/6472))
+- Fix inaccurate per-block Prometheus metrics. ([\#6491](https://github.com/matrix-org/synapse/issues/6491))
+- Fix small performance regression for sending invites. ([\#6493](https://github.com/matrix-org/synapse/issues/6493))
+- Back out cross-signing code added in Synapse 1.5.0, which caused a performance regression. ([\#6494](https://github.com/matrix-org/synapse/issues/6494))
+
+
+Improved Documentation
+----------------------
+
+- Update documentation and variables in user contributed systemd reference file. ([\#6369](https://github.com/matrix-org/synapse/issues/6369), [\#6490](https://github.com/matrix-org/synapse/issues/6490))
+- Fix link in the user directory documentation. ([\#6388](https://github.com/matrix-org/synapse/issues/6388))
+- Add build instructions to the docker readme. ([\#6390](https://github.com/matrix-org/synapse/issues/6390))
+- Switch Ubuntu package install recommendation to use python3 packages in INSTALL.md. ([\#6443](https://github.com/matrix-org/synapse/issues/6443))
+- Write some docs for the quarantine_media api. ([\#6458](https://github.com/matrix-org/synapse/issues/6458))
+- Convert CONTRIBUTING.rst to markdown (among other small fixes). ([\#6461](https://github.com/matrix-org/synapse/issues/6461))
+
+
+Deprecations and Removals
+-------------------------
+
+- Remove admin/v1/users_paginate endpoint. Contributed by Awesome Technologies Innovationslabor GmbH. ([\#5925](https://github.com/matrix-org/synapse/issues/5925))
+- Remove fallback for federation with old servers which lack the /federation/v1/state_ids API. ([\#6488](https://github.com/matrix-org/synapse/issues/6488))
+
+
+Internal Changes
+----------------
+
+- Add benchmarks for structured logging and improve output performance. ([\#6266](https://github.com/matrix-org/synapse/issues/6266))
+- Improve the performance of outputting structured logging. ([\#6322](https://github.com/matrix-org/synapse/issues/6322))
+- Refactor some code in the event authentication path for clarity. ([\#6343](https://github.com/matrix-org/synapse/issues/6343), [\#6468](https://github.com/matrix-org/synapse/issues/6468), [\#6480](https://github.com/matrix-org/synapse/issues/6480))
+- Clean up some unnecessary quotation marks around the codebase. ([\#6362](https://github.com/matrix-org/synapse/issues/6362))
+- Complain on startup instead of 500'ing during runtime when `public_baseurl` isn't set when necessary. ([\#6379](https://github.com/matrix-org/synapse/issues/6379))
+- Add a test scenario to make sure room history purges don't break `/messages` in the future. ([\#6392](https://github.com/matrix-org/synapse/issues/6392))
+- Clarifications for the email configuration settings. ([\#6423](https://github.com/matrix-org/synapse/issues/6423))
+- Add more tests to the blacklist when running in worker mode. ([\#6429](https://github.com/matrix-org/synapse/issues/6429))
+- Move data store specific code out of `SQLBaseStore`. ([\#6454](https://github.com/matrix-org/synapse/issues/6454))
+- Prepare for SQLBaseStore functions to be moved out of the stores. ([\#6464](https://github.com/matrix-org/synapse/issues/6464))
+- Move per database functionality out of the data stores and into a dedicated `Database` class. ([\#6469](https://github.com/matrix-org/synapse/issues/6469))
+- Port synapse.rest.client.v1 to async/await. ([\#6482](https://github.com/matrix-org/synapse/issues/6482))
+- Port synapse.rest.client.v2_alpha to async/await. ([\#6483](https://github.com/matrix-org/synapse/issues/6483))
+- Port SyncHandler to async/await. ([\#6484](https://github.com/matrix-org/synapse/issues/6484))
+- Pass in `Database` object to data stores. ([\#6487](https://github.com/matrix-org/synapse/issues/6487))
+
+
 Synapse 1.6.1 (2019-11-28)
 ==========================
 
diff --git a/changelog.d/5815.feature b/changelog.d/5815.feature
deleted file mode 100644
index ca4df4e7f6..0000000000
--- a/changelog.d/5815.feature
+++ /dev/null
@@ -1 +0,0 @@
-Implement per-room message retention policies.
diff --git a/changelog.d/5858.feature b/changelog.d/5858.feature
deleted file mode 100644
index 55ee93051e..0000000000
--- a/changelog.d/5858.feature
+++ /dev/null
@@ -1 +0,0 @@
-Add etag and count fields to key backup endpoints to help clients guess if there are new keys.
diff --git a/changelog.d/5925.feature b/changelog.d/5925.feature
deleted file mode 100644
index 8025cc8231..0000000000
--- a/changelog.d/5925.feature
+++ /dev/null
@@ -1 +0,0 @@
-Add admin/v2/users endpoint with pagination. Contributed by Awesome Technologies Innovationslabor GmbH.
diff --git a/changelog.d/5925.removal b/changelog.d/5925.removal
deleted file mode 100644
index cbba2855cb..0000000000
--- a/changelog.d/5925.removal
+++ /dev/null
@@ -1 +0,0 @@
-Remove admin/v1/users_paginate endpoint. Contributed by Awesome Technologies Innovationslabor GmbH.
diff --git a/changelog.d/6119.feature b/changelog.d/6119.feature
deleted file mode 100644
index 1492e83c5a..0000000000
--- a/changelog.d/6119.feature
+++ /dev/null
@@ -1 +0,0 @@
-Require User-Interactive Authentication for `/account/3pid/add`, meaning the user's password will be required to add a third-party ID to their account.
\ No newline at end of file
diff --git a/changelog.d/6176.feature b/changelog.d/6176.feature
deleted file mode 100644
index 3c66d689d4..0000000000
--- a/changelog.d/6176.feature
+++ /dev/null
@@ -1 +0,0 @@
-Implement the `/_matrix/federation/unstable/net.atleastfornow/state/<context>` API as drafted in MSC2314.
diff --git a/changelog.d/6237.bugfix b/changelog.d/6237.bugfix
deleted file mode 100644
index 9285600b00..0000000000
--- a/changelog.d/6237.bugfix
+++ /dev/null
@@ -1 +0,0 @@
-Transfer non-standard power levels on room upgrade.
\ No newline at end of file
diff --git a/changelog.d/6241.bugfix b/changelog.d/6241.bugfix
deleted file mode 100644
index 25109ca4a6..0000000000
--- a/changelog.d/6241.bugfix
+++ /dev/null
@@ -1 +0,0 @@
-Fix error from the Pillow library when uploading RGBA images.
diff --git a/changelog.d/6266.misc b/changelog.d/6266.misc
deleted file mode 100644
index 634e421a79..0000000000
--- a/changelog.d/6266.misc
+++ /dev/null
@@ -1 +0,0 @@
-Add benchmarks for structured logging and improve output performance.
diff --git a/changelog.d/6322.misc b/changelog.d/6322.misc
deleted file mode 100644
index 70ef36ca80..0000000000
--- a/changelog.d/6322.misc
+++ /dev/null
@@ -1 +0,0 @@
-Improve the performance of outputting structured logging.
diff --git a/changelog.d/6329.bugfix b/changelog.d/6329.bugfix
deleted file mode 100644
index e558d13b7d..0000000000
--- a/changelog.d/6329.bugfix
+++ /dev/null
@@ -1 +0,0 @@
-Correctly apply the event filter to the `state`, `events_before` and `events_after` fields in the response to `/context` requests.
\ No newline at end of file
diff --git a/changelog.d/6332.bugfix b/changelog.d/6332.bugfix
deleted file mode 100644
index 67d5170ba0..0000000000
--- a/changelog.d/6332.bugfix
+++ /dev/null
@@ -1 +0,0 @@
-Fix caching devices for remote users when using workers, so that we don't attempt to refetch (and potentially fail) each time a user requests devices.
diff --git a/changelog.d/6333.bugfix b/changelog.d/6333.bugfix
deleted file mode 100644
index a25d6ef3cb..0000000000
--- a/changelog.d/6333.bugfix
+++ /dev/null
@@ -1 +0,0 @@
-Prevent account data syncs getting lost across TCP replication.
\ No newline at end of file
diff --git a/changelog.d/6343.misc b/changelog.d/6343.misc
deleted file mode 100644
index d9a44389b9..0000000000
--- a/changelog.d/6343.misc
+++ /dev/null
@@ -1 +0,0 @@
-Refactor some code in the event authentication path for clarity.
diff --git a/changelog.d/6354.feature b/changelog.d/6354.feature
deleted file mode 100644
index fed9db884b..0000000000
--- a/changelog.d/6354.feature
+++ /dev/null
@@ -1 +0,0 @@
-Configure privacy preserving settings by default for the room directory.
diff --git a/changelog.d/6362.misc b/changelog.d/6362.misc
deleted file mode 100644
index b79a5bea99..0000000000
--- a/changelog.d/6362.misc
+++ /dev/null
@@ -1 +0,0 @@
-Clean up some unnecessary quotation marks around the codebase.
\ No newline at end of file
diff --git a/changelog.d/6369.doc b/changelog.d/6369.doc
deleted file mode 100644
index 6db351d7db..0000000000
--- a/changelog.d/6369.doc
+++ /dev/null
@@ -1 +0,0 @@
-Update documentation and variables in user contributed systemd reference file.
diff --git a/changelog.d/6379.misc b/changelog.d/6379.misc
deleted file mode 100644
index 725c2e7d87..0000000000
--- a/changelog.d/6379.misc
+++ /dev/null
@@ -1 +0,0 @@
-Complain on startup instead of 500'ing during runtime when `public_baseurl` isn't set when necessary.
\ No newline at end of file
diff --git a/changelog.d/6388.doc b/changelog.d/6388.doc
deleted file mode 100644
index c777cb6b8f..0000000000
--- a/changelog.d/6388.doc
+++ /dev/null
@@ -1 +0,0 @@
-Fix link in the user directory documentation.
diff --git a/changelog.d/6390.doc b/changelog.d/6390.doc
deleted file mode 100644
index 093411bec1..0000000000
--- a/changelog.d/6390.doc
+++ /dev/null
@@ -1 +0,0 @@
-Add build instructions to the docker readme.
\ No newline at end of file
diff --git a/changelog.d/6392.misc b/changelog.d/6392.misc
deleted file mode 100644
index a00257944f..0000000000
--- a/changelog.d/6392.misc
+++ /dev/null
@@ -1 +0,0 @@
-Add a test scenario to make sure room history purges don't break `/messages` in the future.
diff --git a/changelog.d/6406.bugfix b/changelog.d/6406.bugfix
deleted file mode 100644
index ca9bee084b..0000000000
--- a/changelog.d/6406.bugfix
+++ /dev/null
@@ -1 +0,0 @@
-Fix bug: TypeError in `register_user()` while using LDAP auth module.
diff --git a/changelog.d/6408.bugfix b/changelog.d/6408.bugfix
deleted file mode 100644
index c9babe599b..0000000000
--- a/changelog.d/6408.bugfix
+++ /dev/null
@@ -1 +0,0 @@
-Fix an intermittent exception when handling read-receipts.
diff --git a/changelog.d/6409.feature b/changelog.d/6409.feature
deleted file mode 100644
index 653ff5a5ad..0000000000
--- a/changelog.d/6409.feature
+++ /dev/null
@@ -1 +0,0 @@
-Add ephemeral messages support by partially implementing [MSC2228](https://github.com/matrix-org/matrix-doc/pull/2228).
diff --git a/changelog.d/6411.feature b/changelog.d/6411.feature
new file mode 100644
index 0000000000..ebea4a208d
--- /dev/null
+++ b/changelog.d/6411.feature
@@ -0,0 +1 @@
+Allow custom SAML username mapping functionality through an external provider plugin.
\ No newline at end of file
diff --git a/changelog.d/6420.bugfix b/changelog.d/6420.bugfix
deleted file mode 100644
index aef47cccaa..0000000000
--- a/changelog.d/6420.bugfix
+++ /dev/null
@@ -1 +0,0 @@
-Fix broken guest registration when there are existing blocks of numeric user IDs.
diff --git a/changelog.d/6421.bugfix b/changelog.d/6421.bugfix
deleted file mode 100644
index 7969f7f71d..0000000000
--- a/changelog.d/6421.bugfix
+++ /dev/null
@@ -1 +0,0 @@
-Fix startup error when http proxy is defined.
diff --git a/changelog.d/6423.misc b/changelog.d/6423.misc
deleted file mode 100644
index 9bcd5d36c1..0000000000
--- a/changelog.d/6423.misc
+++ /dev/null
@@ -1 +0,0 @@
-Clarifications for the email configuration settings.
diff --git a/changelog.d/6426.bugfix b/changelog.d/6426.bugfix
deleted file mode 100644
index 3acfde4211..0000000000
--- a/changelog.d/6426.bugfix
+++ /dev/null
@@ -1 +0,0 @@
-Clean up local threepids from user on account deactivation.
\ No newline at end of file
diff --git a/changelog.d/6429.misc b/changelog.d/6429.misc
deleted file mode 100644
index 4b32cdeac6..0000000000
--- a/changelog.d/6429.misc
+++ /dev/null
@@ -1 +0,0 @@
-Add more tests to the blacklist when running in worker mode.
diff --git a/changelog.d/6434.feature b/changelog.d/6434.feature
deleted file mode 100644
index affa5d50c1..0000000000
--- a/changelog.d/6434.feature
+++ /dev/null
@@ -1 +0,0 @@
-Add support for MSC 2367, which allows specifying a reason on all membership events.
diff --git a/changelog.d/6436.bugfix b/changelog.d/6436.bugfix
deleted file mode 100644
index 954a4e1d84..0000000000
--- a/changelog.d/6436.bugfix
+++ /dev/null
@@ -1 +0,0 @@
-Fix a bug where a room could become unusable with a low retention policy and a low activity.
diff --git a/changelog.d/6443.doc b/changelog.d/6443.doc
deleted file mode 100644
index 67c59f92ee..0000000000
--- a/changelog.d/6443.doc
+++ /dev/null
@@ -1 +0,0 @@
-Switch Ubuntu package install recommendation to use python3 packages in INSTALL.md.
\ No newline at end of file
diff --git a/changelog.d/6449.bugfix b/changelog.d/6449.bugfix
deleted file mode 100644
index 002f33c450..0000000000
--- a/changelog.d/6449.bugfix
+++ /dev/null
@@ -1 +0,0 @@
-Fix error when using synapse_port_db on a vanilla synapse db.
diff --git a/changelog.d/6451.bugfix b/changelog.d/6451.bugfix
deleted file mode 100644
index 23b67583ec..0000000000
--- a/changelog.d/6451.bugfix
+++ /dev/null
@@ -1 +0,0 @@
-Fix uploading multiple cross signing signatures for the same user.
diff --git a/changelog.d/6454.misc b/changelog.d/6454.misc
deleted file mode 100644
index 9e5259157c..0000000000
--- a/changelog.d/6454.misc
+++ /dev/null
@@ -1 +0,0 @@
-Move data store specific code out of `SQLBaseStore`.
diff --git a/changelog.d/6458.doc b/changelog.d/6458.doc
deleted file mode 100644
index 3a9f831d89..0000000000
--- a/changelog.d/6458.doc
+++ /dev/null
@@ -1 +0,0 @@
-Write some docs for the quarantine_media api.
diff --git a/changelog.d/6461.doc b/changelog.d/6461.doc
deleted file mode 100644
index 1502fa2855..0000000000
--- a/changelog.d/6461.doc
+++ /dev/null
@@ -1 +0,0 @@
-Convert CONTRIBUTING.rst to markdown (among other small fixes).
\ No newline at end of file
diff --git a/changelog.d/6462.bugfix b/changelog.d/6462.bugfix
deleted file mode 100644
index c435939526..0000000000
--- a/changelog.d/6462.bugfix
+++ /dev/null
@@ -1 +0,0 @@
-Fix bug which lead to exceptions being thrown in a loop when a cross-signed device is deleted.
diff --git a/changelog.d/6464.misc b/changelog.d/6464.misc
deleted file mode 100644
index bd65276ef6..0000000000
--- a/changelog.d/6464.misc
+++ /dev/null
@@ -1 +0,0 @@
-Prepare SQLBaseStore functions being moved out of the stores.
diff --git a/changelog.d/6468.misc b/changelog.d/6468.misc
deleted file mode 100644
index d9a44389b9..0000000000
--- a/changelog.d/6468.misc
+++ /dev/null
@@ -1 +0,0 @@
-Refactor some code in the event authentication path for clarity.
diff --git a/changelog.d/6470.bugfix b/changelog.d/6470.bugfix
deleted file mode 100644
index c08b34c14c..0000000000
--- a/changelog.d/6470.bugfix
+++ /dev/null
@@ -1 +0,0 @@
-Fix `synapse_port_db` not exiting with a 0 code if something went wrong during the port process.
diff --git a/changelog.d/6472.bugfix b/changelog.d/6472.bugfix
deleted file mode 100644
index 598efb79fc..0000000000
--- a/changelog.d/6472.bugfix
+++ /dev/null
@@ -1 +0,0 @@
-Improve sanity-checking when receiving events over federation.
diff --git a/changelog.d/6480.misc b/changelog.d/6480.misc
deleted file mode 100644
index d9a44389b9..0000000000
--- a/changelog.d/6480.misc
+++ /dev/null
@@ -1 +0,0 @@
-Refactor some code in the event authentication path for clarity.
diff --git a/changelog.d/6482.misc b/changelog.d/6482.misc
deleted file mode 100644
index bdef9cf40a..0000000000
--- a/changelog.d/6482.misc
+++ /dev/null
@@ -1 +0,0 @@
-Port synapse.rest.client.v1 to async/await.
diff --git a/changelog.d/6483.misc b/changelog.d/6483.misc
deleted file mode 100644
index cb2cd2bc39..0000000000
--- a/changelog.d/6483.misc
+++ /dev/null
@@ -1 +0,0 @@
-Port synapse.rest.client.v2_alpha to async/await.
diff --git a/changelog.d/6497.bugfix b/changelog.d/6497.bugfix
new file mode 100644
index 0000000000..92ed08fc40
--- /dev/null
+++ b/changelog.d/6497.bugfix
@@ -0,0 +1 @@
+Fix error message when setting your profile's avatar URL mentioning displaynames, and prevent NoneType avatar_urls.
\ No newline at end of file
diff --git a/changelog.d/6499.bugfix b/changelog.d/6499.bugfix
new file mode 100644
index 0000000000..299feba0f8
--- /dev/null
+++ b/changelog.d/6499.bugfix
@@ -0,0 +1 @@
+Fix support for SQLite 3.7.
diff --git a/changelog.d/6501.misc b/changelog.d/6501.misc
new file mode 100644
index 0000000000..255f45a9c3
--- /dev/null
+++ b/changelog.d/6501.misc
@@ -0,0 +1 @@
+Refactor get_events_from_store_or_dest to return a dict.
diff --git a/changelog.d/6502.removal b/changelog.d/6502.removal
new file mode 100644
index 0000000000..0b72261d58
--- /dev/null
+++ b/changelog.d/6502.removal
@@ -0,0 +1 @@
+Remove redundant code from event authorisation implementation.
diff --git a/changelog.d/6503.misc b/changelog.d/6503.misc
new file mode 100644
index 0000000000..e4e9a5a3d4
--- /dev/null
+++ b/changelog.d/6503.misc
@@ -0,0 +1 @@
+Move get_state methods into FederationHandler.
diff --git a/changelog.d/6505.misc b/changelog.d/6505.misc
new file mode 100644
index 0000000000..3a75b2d9dd
--- /dev/null
+++ b/changelog.d/6505.misc
@@ -0,0 +1 @@
+Make `make_deferred_yieldable` work with async/await.
diff --git a/changelog.d/6506.misc b/changelog.d/6506.misc
new file mode 100644
index 0000000000..99d7a70bcf
--- /dev/null
+++ b/changelog.d/6506.misc
@@ -0,0 +1 @@
+Remove `SnapshotCache` in favour of `ResponseCache`.
diff --git a/changelog.d/6510.misc b/changelog.d/6510.misc
new file mode 100644
index 0000000000..214f06539b
--- /dev/null
+++ b/changelog.d/6510.misc
@@ -0,0 +1 @@
+Change phone home stats to not assume there is a single database and report information about the database used by the main data store.
diff --git a/changelog.d/6512.misc b/changelog.d/6512.misc
new file mode 100644
index 0000000000..37a8099eec
--- /dev/null
+++ b/changelog.d/6512.misc
@@ -0,0 +1 @@
+Silence mypy errors for files outside those specified.
diff --git a/changelog.d/6514.bugfix b/changelog.d/6514.bugfix
new file mode 100644
index 0000000000..6dc1985c24
--- /dev/null
+++ b/changelog.d/6514.bugfix
@@ -0,0 +1 @@
+Fix race which occasionally caused deleted devices to reappear.
diff --git a/contrib/systemd/matrix-synapse.service b/contrib/systemd/matrix-synapse.service
index bd492544b6..813717b032 100644
--- a/contrib/systemd/matrix-synapse.service
+++ b/contrib/systemd/matrix-synapse.service
@@ -25,7 +25,7 @@ Restart=on-abort
 User=synapse
 Group=nogroup
 
-WorkingDirectory=/opt/synapse
+WorkingDirectory=/home/synapse/synapse
 ExecStart=/home/synapse/synapse/env/bin/python -m synapse.app.homeserver --config-path=/home/synapse/synapse/homeserver.yaml
 SyslogIdentifier=matrix-synapse
 
diff --git a/docs/saml_mapping_providers.md b/docs/saml_mapping_providers.md
new file mode 100644
index 0000000000..92f2380488
--- /dev/null
+++ b/docs/saml_mapping_providers.md
@@ -0,0 +1,77 @@
+# SAML Mapping Providers
+
+A SAML mapping provider is a Python class (loaded via a Python module) that
+works out how to map attributes of a SAML response object to Matrix-specific
+user attributes. Details such as the user ID localpart, displayname, and even
+avatar URLs can all be derived from attributes returned by an SSO service.
+
+As an example, an SSO service may return the email address
+"john.smith@example.com" for a user, whereas Synapse will need to figure out how
+to turn that into a displayname when creating a Matrix user for this individual.
+It may choose `John Smith`, or `Smith, John [Example.com]` or any number of
+variations. As each Synapse configuration may want something different, this is
+where SAML mapping providers come into play.
+
+## Enabling Providers
+
+External mapping providers are provided to Synapse in the form of an external
+Python module. Retrieve this module from [PyPi](https://pypi.org) or elsewhere,
+then tell Synapse where to look for the handler class by editing the
+`saml2_config.user_mapping_provider.module` config option.
+
+`saml2_config.user_mapping_provider.config` allows you to provide custom
+configuration options to the module. Check the module's documentation for the
+options it provides (if any). The options listed by default are for the
+user mapping provider built in to Synapse. If using a custom module, you should
+comment these options out and use those specified by the module instead.
+
+## Building a Custom Mapping Provider
+
+A custom mapping provider must specify the following methods:
+
+* `__init__(self, parsed_config)`
+   - Arguments:
+     - `parsed_config` - A configuration object that is the return value of the
+       `parse_config` method. You should set any configuration options needed by
+       the module here.
+* `saml_response_to_user_attributes(self, saml_response, failures)`
+    - Arguments:
+      - `saml_response` - A `saml2.response.AuthnResponse` object to extract user
+                          information from.
+      - `failures` - An `int` that represents the number of times the returned
+                     mxid localpart mapping has failed. Use it to build a
+                     deduplicated mxid localpart to return instead.
+                     For example, if this method returns
+                     `john.doe` as the value of `mxid_localpart` in the returned
+                     dict, and that is already taken on the homeserver, this
+                     method will be called again with the same parameters but
+                     with failures=1. The method should then return a different
+                     `mxid_localpart` value, such as `john.doe1`.
+    - This method must return a dictionary, which will then be used by Synapse
+      to build a new user. The following keys are allowed:
+       * `mxid_localpart` - Required. The mxid localpart of the new user.
+       * `displayname` - The displayname of the new user. If not provided, will default to
+                         the value of `mxid_localpart`.
+* `parse_config(config)`
+    - This method should be decorated with `@staticmethod`.
+    - Arguments:
+        - `config` - A `dict` representing the parsed content of the
+          `saml2_config.user_mapping_provider.config` homeserver config option.
+           This method runs on homeserver startup; providers should extract
+           any option values they need here.
+    - Whatever is returned will be passed back to the user mapping provider module's
+      `__init__` method during construction.
+* `get_saml_attributes(config)`
+    - This method should be decorated with `@staticmethod`.
+    - Arguments:
+        - `config` - An object resulting from a call to `parse_config`.
+    - Returns a tuple of two sets. The first set contains the SAML auth
+      response attributes that are required for the module to function, while
+      the second contains those attributes which can be used if available
+      but are not necessary.
+
+## Synapse's Default Provider
+
+Synapse falls back to a built-in SAML mapping provider if a custom provider
+isn't specified in the config. It is located at
+[`synapse.handlers.saml_handler.DefaultSamlMappingProvider`](../synapse/handlers/saml_handler.py).
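Taken together, the interface documented in the new file above amounts to a small class. Below is a minimal sketch of such a provider, matching the `mapping_provider.SamlMappingProvider` module path used as an example in the sample config; the SAML attribute names (`uid`, `displayName`) and the `prefix` config option are illustrative assumptions, not part of the interface.

```python
# A minimal sketch of a custom SAML mapping provider following the
# interface documented above. "uid", "displayName" and "prefix" are
# illustrative assumptions, not part of the interface itself.


class _ProviderConfig(object):
    def __init__(self, prefix):
        self.prefix = prefix


class SamlMappingProvider(object):
    def __init__(self, parsed_config):
        # parsed_config is whatever parse_config() returned below.
        self._prefix = parsed_config.prefix

    @staticmethod
    def parse_config(config):
        # `config` is the dict parsed from
        # saml2_config.user_mapping_provider.config; runs at startup.
        return _ProviderConfig(prefix=config.get("prefix", ""))

    @staticmethod
    def get_saml_attributes(config):
        # (required attributes, optional attributes)
        return {"uid"}, {"displayName"}

    def saml_response_to_user_attributes(self, saml_response, failures):
        # pysaml2 responses expose the assertion's attributes via `ava`,
        # a dict mapping attribute names to lists of values.
        uid = saml_response.ava["uid"][0]
        displayname = saml_response.ava.get("displayName", [None])[0]

        # Append `failures` so a retry yields a different localpart
        # (e.g. "john.doe" -> "john.doe1").
        localpart = self._prefix + uid + (str(failures) if failures else "")
        return {"mxid_localpart": localpart, "displayname": displayname}
```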
diff --git a/docs/sample_config.yaml b/docs/sample_config.yaml
index 10664ae8f7..4d44e631d1 100644
--- a/docs/sample_config.yaml
+++ b/docs/sample_config.yaml
@@ -1250,33 +1250,58 @@ saml2_config:
   #
   #config_path: "CONFDIR/sp_conf.py"
 
-  # the lifetime of a SAML session. This defines how long a user has to
+  # The lifetime of a SAML session. This defines how long a user has to
   # complete the authentication process, if allow_unsolicited is unset.
   # The default is 5 minutes.
   #
   #saml_session_lifetime: 5m
 
-  # The SAML attribute (after mapping via the attribute maps) to use to derive
-  # the Matrix ID from. 'uid' by default.
+  # An external module can be provided here as a custom solution to
+  # mapping attributes returned from a saml provider onto a matrix user.
   #
-  #mxid_source_attribute: displayName
-
-  # The mapping system to use for mapping the saml attribute onto a matrix ID.
-  # Options include:
-  #  * 'hexencode' (which maps unpermitted characters to '=xx')
-  #  * 'dotreplace' (which replaces unpermitted characters with '.').
-  # The default is 'hexencode'.
-  #
-  #mxid_mapping: dotreplace
+  user_mapping_provider:
+    # The custom module's class. Uncomment to use a custom module.
+    #
+    #module: mapping_provider.SamlMappingProvider
 
-  # In previous versions of synapse, the mapping from SAML attribute to MXID was
-  # always calculated dynamically rather than stored in a table. For backwards-
-  # compatibility, we will look for user_ids matching such a pattern before
-  # creating a new account.
+    # Custom configuration values for the module. The options below are
+    # intended for the built-in provider; they should be changed if
+    # using a custom module. This section will be passed as a Python
+    # dictionary to the module's `parse_config` method.
+    #
+    config:
+      # The SAML attribute (after mapping via the attribute maps) to use
+      # to derive the Matrix ID from. 'uid' by default.
+      #
+      # Note: This used to be configured by the
+      # saml2_config.mxid_source_attribute option. If that is still
+      # defined, its value will be used instead.
+      #
+      #mxid_source_attribute: displayName
+
+      # The mapping system to use for mapping the saml attribute onto a
+      # matrix ID.
+      #
+      # Options include:
+      #  * 'hexencode' (which maps unpermitted characters to '=xx')
+      #  * 'dotreplace' (which replaces unpermitted characters with
+      #     '.').
+      # The default is 'hexencode'.
+      #
+      # Note: This used to be configured by the
+      # saml2_config.mxid_mapping option. If that is still defined, its
+      # value will be used instead.
+      #
+      #mxid_mapping: dotreplace
+
+  # In previous versions of synapse, the mapping from SAML attribute to
+  # MXID was always calculated dynamically rather than stored in a
+  # table. For backwards-compatibility, we will look for user_ids
+  # matching such a pattern before creating a new account.
   #
   # This setting controls the SAML attribute which will be used for this
-  # backwards-compatibility lookup. Typically it should be 'uid', but if the
-  # attribute maps are changed, it may be necessary to change it.
+  # backwards-compatibility lookup. Typically it should be 'uid', but if
+  # the attribute maps are changed, it may be necessary to change it.
   #
   # The default is 'uid'.
   #
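For reference, the two mapping systems described in the comments above behave roughly as follows. This is an illustrative sketch of the behaviour only, not Synapse's actual implementation, and it ignores the extra case-folding the real mapper performs.

```python
import string

# Characters permitted in a Matrix user ID localpart.
ALLOWED = set(string.ascii_lowercase + string.digits + "._=-/")


def hexencode(username):
    # Unpermitted characters (and "=" itself, the escape character)
    # become "=xx", the lowercase hex of each UTF-8 byte.
    out = []
    for ch in username.lower():
        if ch in ALLOWED and ch != "=":
            out.append(ch)
        else:
            out.extend("=%02x" % b for b in ch.encode("utf-8"))
    return "".join(out)


def dotreplace(username):
    # Unpermitted characters are simply replaced with ".".
    return "".join(ch if ch in ALLOWED else "." for ch in username.lower())


print(hexencode("John Smith"))   # john=20smith
print(dotreplace("John Smith"))  # john.smith
```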
diff --git a/mypy.ini b/mypy.ini
index 1d77c0ecc8..a66434b76b 100644
--- a/mypy.ini
+++ b/mypy.ini
@@ -1,7 +1,7 @@
 [mypy]
 namespace_packages = True
 plugins = mypy_zope:plugin
-follow_imports = normal
+follow_imports = silent
 check_untyped_defs = True
 show_error_codes = True
 show_traceback = True
diff --git a/scripts-dev/update_database b/scripts-dev/update_database
index 27a1ad1e7e..1776d202c5 100755
--- a/scripts-dev/update_database
+++ b/scripts-dev/update_database
@@ -58,10 +58,10 @@ if __name__ == "__main__":
             " on it."
         )
     )
-    parser.add_argument("-v", action='store_true')
+    parser.add_argument("-v", action="store_true")
     parser.add_argument(
         "--database-config",
-        type=argparse.FileType('r'),
+        type=argparse.FileType("r"),
         required=True,
         help="A database config file for either a SQLite3 database or a PostgreSQL one.",
     )
@@ -101,10 +101,7 @@ if __name__ == "__main__":
 
     # Instantiate and initialise the homeserver object.
     hs = MockHomeserver(
-        config,
-        database_engine,
-        db_conn,
-        db_config=config.database_config,
+        config, database_engine, db_conn, db_config=config.database_config,
     )
     # setup instantiates the store within the homeserver object.
     hs.setup()
@@ -112,13 +109,13 @@ if __name__ == "__main__":
 
     @defer.inlineCallbacks
     def run_background_updates():
-        yield store.run_background_updates(sleep=False)
+        yield store.db.updates.run_background_updates(sleep=False)
         # Stop the reactor to exit the script once every background update is run.
         reactor.stop()
 
     # Apply all background updates on the database.
-    reactor.callWhenRunning(lambda: run_as_background_process(
-        "background_updates", run_background_updates
-    ))
+    reactor.callWhenRunning(
+        lambda: run_as_background_process("background_updates", run_background_updates)
+    )
 
     reactor.run()
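In isolation, the pattern this script ends with (schedule the updates as a background process once the reactor starts, then stop the reactor when they complete) looks roughly like the following sketch, with the store call stubbed out.

```python
from twisted.internet import defer, reactor

from synapse.metrics.background_process_metrics import run_as_background_process


@defer.inlineCallbacks
def run_background_updates():
    # Stand-in for: yield store.db.updates.run_background_updates(sleep=False)
    yield defer.succeed(None)
    # Stop the reactor to exit once every background update has run.
    reactor.stop()


reactor.callWhenRunning(
    lambda: run_as_background_process("background_updates", run_background_updates)
)
reactor.run()
```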
diff --git a/scripts/synapse_port_db b/scripts/synapse_port_db
index c4cf11d19a..e393a9b2f7 100755
--- a/scripts/synapse_port_db
+++ b/scripts/synapse_port_db
@@ -55,6 +55,7 @@ from synapse.storage.data_stores.main.stats import StatsStore
 from synapse.storage.data_stores.main.user_directory import (
     UserDirectoryBackgroundUpdateStore,
 )
+from synapse.storage.database import Database
 from synapse.storage.engines import create_engine
 from synapse.storage.prepare_database import prepare_database
 from synapse.util import Clock
@@ -139,48 +140,15 @@ class Store(
     UserDirectoryBackgroundUpdateStore,
     StatsStore,
 ):
-    def __init__(self, db_conn, hs):
-        super().__init__(db_conn, hs)
-        self.db_pool = hs.get_db_pool()
-
-    @defer.inlineCallbacks
-    def runInteraction(self, desc, func, *args, **kwargs):
-        def r(conn):
-            try:
-                i = 0
-                N = 5
-                while True:
-                    try:
-                        txn = conn.cursor()
-                        return func(
-                            LoggingTransaction(txn, desc, self.database_engine, [], []),
-                            *args,
-                            **kwargs
-                        )
-                    except self.database_engine.module.DatabaseError as e:
-                        if self.database_engine.is_deadlock(e):
-                            logger.warning("[TXN DEADLOCK] {%s} %d/%d", desc, i, N)
-                            if i < N:
-                                i += 1
-                                conn.rollback()
-                                continue
-                        raise
-            except Exception as e:
-                logger.debug("[TXN FAIL] {%s} %s", desc, e)
-                raise
-
-        with PreserveLoggingContext():
-            return (yield self.db_pool.runWithConnection(r))
-
     def execute(self, f, *args, **kwargs):
-        return self.runInteraction(f.__name__, f, *args, **kwargs)
+        return self.db.runInteraction(f.__name__, f, *args, **kwargs)
 
     def execute_sql(self, sql, *args):
         def r(txn):
             txn.execute(sql, args)
             return txn.fetchall()
 
-        return self.runInteraction("execute_sql", r)
+        return self.db.runInteraction("execute_sql", r)
 
     def insert_many_txn(self, txn, table, headers, rows):
         sql = "INSERT INTO %s (%s) VALUES (%s)" % (
@@ -223,7 +191,7 @@ class Porter(object):
     def setup_table(self, table):
         if table in APPEND_ONLY_TABLES:
             # It's safe to just carry on inserting.
-            row = yield self.postgres_store.simple_select_one(
+            row = yield self.postgres_store.db.simple_select_one(
                 table="port_from_sqlite3",
                 keyvalues={"table_name": table},
                 retcols=("forward_rowid", "backward_rowid"),
@@ -233,12 +201,14 @@ class Porter(object):
             total_to_port = None
             if row is None:
                 if table == "sent_transactions":
-                    forward_chunk, already_ported, total_to_port = (
-                        yield self._setup_sent_transactions()
-                    )
+                    (
+                        forward_chunk,
+                        already_ported,
+                        total_to_port,
+                    ) = yield self._setup_sent_transactions()
                     backward_chunk = 0
                 else:
-                    yield self.postgres_store.simple_insert(
+                    yield self.postgres_store.db.simple_insert(
                         table="port_from_sqlite3",
                         values={
                             "table_name": table,
@@ -268,7 +238,7 @@ class Porter(object):
 
             yield self.postgres_store.execute(delete_all)
 
-            yield self.postgres_store.simple_insert(
+            yield self.postgres_store.db.simple_insert(
                 table="port_from_sqlite3",
                 values={"table_name": table, "forward_rowid": 1, "backward_rowid": 0},
             )
@@ -322,7 +292,7 @@ class Porter(object):
         if table == "user_directory_stream_pos":
             # We need to make sure there is a single row, `(X, null), as that is
             # what synapse expects to be there.
-            yield self.postgres_store.simple_insert(
+            yield self.postgres_store.db.simple_insert(
                 table=table, values={"stream_id": None}
             )
             self.progress.update(table, table_size)  # Mark table as done
@@ -363,7 +333,9 @@ class Porter(object):
 
                 return headers, forward_rows, backward_rows
 
-            headers, frows, brows = yield self.sqlite_store.runInteraction("select", r)
+            headers, frows, brows = yield self.sqlite_store.db.runInteraction(
+                "select", r
+            )
 
             if frows or brows:
                 if frows:
@@ -377,7 +349,7 @@ class Porter(object):
                 def insert(txn):
                     self.postgres_store.insert_many_txn(txn, table, headers[1:], rows)
 
-                    self.postgres_store.simple_update_one_txn(
+                    self.postgres_store.db.simple_update_one_txn(
                         txn,
                         table="port_from_sqlite3",
                         keyvalues={"table_name": table},
@@ -416,7 +388,7 @@ class Porter(object):
 
                 return headers, rows
 
-            headers, rows = yield self.sqlite_store.runInteraction("select", r)
+            headers, rows = yield self.sqlite_store.db.runInteraction("select", r)
 
             if rows:
                 forward_chunk = rows[-1][0] + 1
@@ -433,8 +405,8 @@ class Porter(object):
                     rows_dict = []
                     for row in rows:
                         d = dict(zip(headers, row))
-                        if "\0" in d['value']:
-                            logger.warning('dropping search row %s', d)
+                        if "\0" in d["value"]:
+                            logger.warning("dropping search row %s", d)
                         else:
                             rows_dict.append(d)
 
@@ -454,7 +426,7 @@ class Porter(object):
                         ],
                     )
 
-                    self.postgres_store.simple_update_one_txn(
+                    self.postgres_store.db.simple_update_one_txn(
                         txn,
                         table="port_from_sqlite3",
                         keyvalues={"table_name": "event_search"},
@@ -504,17 +476,14 @@ class Porter(object):
         self.progress.set_state("Preparing %s" % config["name"])
         conn = self.setup_db(config, engine)
 
-        db_pool = adbapi.ConnectionPool(
-            config["name"], **config["args"]
-        )
+        db_pool = adbapi.ConnectionPool(config["name"], **config["args"])
 
         hs = MockHomeserver(self.hs_config, engine, conn, db_pool)
 
-        store = Store(conn, hs)
+        store = Store(Database(hs), conn, hs)
 
-        yield store.runInteraction(
-            "%s_engine.check_database" % config["name"],
-            engine.check_database,
+        yield store.db.runInteraction(
+            "%s_engine.check_database" % config["name"], engine.check_database,
         )
 
         return store
@@ -522,7 +491,9 @@ class Porter(object):
     @defer.inlineCallbacks
     def run_background_updates_on_postgres(self):
         # Manually apply all background updates on the PostgreSQL database.
-        postgres_ready = yield self.postgres_store.has_completed_background_updates()
+        postgres_ready = (
+            yield self.postgres_store.db.updates.has_completed_background_updates()
+        )
 
         if not postgres_ready:
             # Only say that we're running background updates when there are background
@@ -530,9 +501,9 @@ class Porter(object):
             self.progress.set_state("Running background updates on PostgreSQL")
 
         while not postgres_ready:
-            yield self.postgres_store.do_next_background_update(100)
+            yield self.postgres_store.db.updates.do_next_background_update(100)
             postgres_ready = yield (
-                self.postgres_store.has_completed_background_updates()
+                self.postgres_store.db.updates.has_completed_background_updates()
             )
 
     @defer.inlineCallbacks
@@ -541,7 +512,9 @@ class Porter(object):
             self.sqlite_store = yield self.build_db_store(self.sqlite_config)
 
             # Check if all background updates are done, abort if not.
-            updates_complete = yield self.sqlite_store.has_completed_background_updates()
+            updates_complete = (
+                yield self.sqlite_store.db.updates.has_completed_background_updates()
+            )
             if not updates_complete:
                 sys.stderr.write(
                     "Pending background updates exist in the SQLite3 database."
@@ -582,22 +555,22 @@ class Porter(object):
                 )
 
             try:
-                yield self.postgres_store.runInteraction("alter_table", alter_table)
+                yield self.postgres_store.db.runInteraction("alter_table", alter_table)
             except Exception:
                 # On Error Resume Next
                 pass
 
-            yield self.postgres_store.runInteraction(
+            yield self.postgres_store.db.runInteraction(
                 "create_port_table", create_port_table
             )
 
             # Step 2. Get tables.
             self.progress.set_state("Fetching tables")
-            sqlite_tables = yield self.sqlite_store.simple_select_onecol(
+            sqlite_tables = yield self.sqlite_store.db.simple_select_onecol(
                 table="sqlite_master", keyvalues={"type": "table"}, retcol="name"
             )
 
-            postgres_tables = yield self.postgres_store.simple_select_onecol(
+            postgres_tables = yield self.postgres_store.db.simple_select_onecol(
                 table="information_schema.tables",
                 keyvalues={},
                 retcol="distinct table_name",
@@ -687,11 +660,11 @@ class Porter(object):
             rows = txn.fetchall()
             headers = [column[0] for column in txn.description]
 
-            ts_ind = headers.index('ts')
+            ts_ind = headers.index("ts")
 
             return headers, [r for r in rows if r[ts_ind] < yesterday]
 
-        headers, rows = yield self.sqlite_store.runInteraction("select", r)
+        headers, rows = yield self.sqlite_store.db.runInteraction("select", r)
 
         rows = self._convert_rows("sent_transactions", headers, rows)
 
@@ -724,7 +697,7 @@ class Porter(object):
         next_chunk = yield self.sqlite_store.execute(get_start_id)
         next_chunk = max(max_inserted_rowid + 1, next_chunk)
 
-        yield self.postgres_store.simple_insert(
+        yield self.postgres_store.db.simple_insert(
             table="port_from_sqlite3",
             values={
                 "table_name": "sent_transactions",
@@ -737,7 +710,7 @@ class Porter(object):
             txn.execute(
                 "SELECT count(*) FROM sent_transactions" " WHERE ts >= ?", (yesterday,)
             )
-            size, = txn.fetchone()
+            (size,) = txn.fetchone()
             return int(size)
 
         remaining_count = yield self.sqlite_store.execute(get_sent_table_size)
@@ -790,7 +763,7 @@ class Porter(object):
             next_id = curr_id + 1
             txn.execute("ALTER SEQUENCE state_group_id_seq RESTART WITH %s", (next_id,))
 
-        return self.postgres_store.runInteraction("setup_state_group_id_seq", r)
+        return self.postgres_store.db.runInteraction("setup_state_group_id_seq", r)
 
 
 ##############################################
@@ -871,7 +844,7 @@ class CursesProgress(Progress):
         duration = int(now) - int(self.start_time)
 
         minutes, seconds = divmod(duration, 60)
-        duration_str = '%02dm %02ds' % (minutes, seconds)
+        duration_str = "%02dm %02ds" % (minutes, seconds)
 
         if self.finished:
             status = "Time spent: %s (Done!)" % (duration_str,)
@@ -881,7 +854,7 @@ class CursesProgress(Progress):
                 left = float(self.total_remaining) / self.total_processed
 
                 est_remaining = (int(now) - self.start_time) * left
-                est_remaining_str = '%02dm %02ds remaining' % divmod(est_remaining, 60)
+                est_remaining_str = "%02dm %02ds remaining" % divmod(est_remaining, 60)
             else:
                 est_remaining_str = "Unknown"
             status = "Time spent: %s (est. remaining: %s)" % (
@@ -967,7 +940,7 @@ if __name__ == "__main__":
         description="A script to port an existing synapse SQLite database to"
         " a new PostgreSQL database."
     )
-    parser.add_argument("-v", action='store_true')
+    parser.add_argument("-v", action="store_true")
     parser.add_argument(
         "--sqlite-database",
         required=True,
@@ -976,12 +949,12 @@ if __name__ == "__main__":
     )
     parser.add_argument(
         "--postgres-config",
-        type=argparse.FileType('r'),
+        type=argparse.FileType("r"),
         required=True,
         help="The database config file for the PostgreSQL database",
     )
     parser.add_argument(
-        "--curses", action='store_true', help="display a curses based progress UI"
+        "--curses", action="store_true", help="display a curses based progress UI"
     )
 
     parser.add_argument(
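The mechanical change running through this script, and through the stores below, is that generic helpers such as `simple_select_one` and `runInteraction` now live on a dedicated `Database` object at `store.db`, with background updates under `store.db.updates`. A minimal sketch of the new calling convention, reusing the `port_from_sqlite3` bookkeeping table the script already queries:

```python
from twisted.internet import defer


@defer.inlineCallbacks
def lookup_port_progress(store, table):
    # Previously: yield store.simple_select_one(...). The helper now
    # hangs off the Database object rather than SQLBaseStore.
    row = yield store.db.simple_select_one(
        table="port_from_sqlite3",       # bookkeeping table used by the script
        keyvalues={"table_name": table},
        retcols=("forward_rowid", "backward_rowid"),
        allow_none=True,
    )
    defer.returnValue(row)
```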
diff --git a/synapse/__init__.py b/synapse/__init__.py
index f99de2f3f3..c67a51a8d5 100644
--- a/synapse/__init__.py
+++ b/synapse/__init__.py
@@ -36,7 +36,7 @@ try:
 except ImportError:
     pass
 
-__version__ = "1.6.1"
+__version__ = "1.7.0rc1"
 
 if bool(os.environ.get("SYNAPSE_TEST_PATCH_LOG_CONTEXTS", False)):
     # We import here so that we don't have to install a bunch of deps when
diff --git a/synapse/app/_base.py b/synapse/app/_base.py
index 2ac7d5c064..9c96816096 100644
--- a/synapse/app/_base.py
+++ b/synapse/app/_base.py
@@ -269,7 +269,7 @@ def start(hs, listeners=None):
 
         # It is now safe to start your Synapse.
         hs.start_listening(listeners)
-        hs.get_datastore().start_profiling()
+        hs.get_datastore().db.start_profiling()
 
         setup_sentry(hs)
         setup_sdnotify(hs)
diff --git a/synapse/app/federation_sender.py b/synapse/app/federation_sender.py
index 448e45e00f..f24920a7d6 100644
--- a/synapse/app/federation_sender.py
+++ b/synapse/app/federation_sender.py
@@ -40,6 +40,7 @@ from synapse.replication.slave.storage.transactions import SlavedTransactionStor
 from synapse.replication.tcp.client import ReplicationClientHandler
 from synapse.replication.tcp.streams._base import ReceiptsStream
 from synapse.server import HomeServer
+from synapse.storage.database import Database
 from synapse.storage.engines import create_engine
 from synapse.types import ReadReceipt
 from synapse.util.async_helpers import Linearizer
@@ -59,8 +60,8 @@ class FederationSenderSlaveStore(
     SlavedDeviceStore,
     SlavedPresenceStore,
 ):
-    def __init__(self, db_conn, hs):
-        super(FederationSenderSlaveStore, self).__init__(db_conn, hs)
+    def __init__(self, database: Database, db_conn, hs):
+        super(FederationSenderSlaveStore, self).__init__(database, db_conn, hs)
 
         # We pull out the current federation stream position now so that we
         # always have a known value for the federation position in memory so
diff --git a/synapse/app/homeserver.py b/synapse/app/homeserver.py
index 267aebaae9..032010600a 100644
--- a/synapse/app/homeserver.py
+++ b/synapse/app/homeserver.py
@@ -68,9 +68,9 @@ from synapse.rest.key.v2 import KeyApiV2Resource
 from synapse.rest.media.v0.content_repository import ContentRepoResource
 from synapse.rest.well_known import WellKnownResource
 from synapse.server import HomeServer
-from synapse.storage import DataStore, are_all_users_on_domain
+from synapse.storage import DataStore
 from synapse.storage.engines import IncorrectDatabaseSetup, create_engine
-from synapse.storage.prepare_database import UpgradeDatabaseException, prepare_database
+from synapse.storage.prepare_database import UpgradeDatabaseException
 from synapse.util.caches import CACHE_SIZE_FACTOR
 from synapse.util.httpresourcetree import create_resource_tree
 from synapse.util.manhole import manhole
@@ -294,22 +294,6 @@ class SynapseHomeServer(HomeServer):
             else:
                 logger.warning("Unrecognized listener type: %s", listener["type"])
 
-    def run_startup_checks(self, db_conn, database_engine):
-        all_users_native = are_all_users_on_domain(
-            db_conn.cursor(), database_engine, self.hostname
-        )
-        if not all_users_native:
-            quit_with_error(
-                "Found users in database not native to %s!\n"
-                "You cannot changed a synapse server_name after it's been configured"
-                % (self.hostname,)
-            )
-
-        try:
-            database_engine.check_database(db_conn.cursor())
-        except IncorrectDatabaseSetup as e:
-            quit_with_error(str(e))
-
 
 # Gauges to expose monthly active user control metrics
 current_mau_gauge = Gauge("synapse_admin_mau:current", "Current MAU")
@@ -357,16 +341,12 @@ def setup(config_options):
 
     synapse.config.logger.setup_logging(hs, config, use_worker_options=False)
 
-    logger.info("Preparing database: %s...", config.database_config["name"])
+    logger.info("Setting up server")
 
     try:
-        with hs.get_db_conn(run_new_connection=False) as db_conn:
-            prepare_database(db_conn, database_engine, config=config)
-            database_engine.on_new_connection(db_conn)
-
-            hs.run_startup_checks(db_conn, database_engine)
-
-            db_conn.commit()
+        hs.setup()
+    except IncorrectDatabaseSetup as e:
+        quit_with_error(str(e))
     except UpgradeDatabaseException:
         sys.stderr.write(
             "\nFailed to upgrade database.\n"
@@ -375,9 +355,6 @@ def setup(config_options):
         )
         sys.exit(1)
 
-    logger.info("Database prepared in %s.", config.database_config["name"])
-
-    hs.setup()
     hs.setup_master()
 
     @defer.inlineCallbacks
@@ -436,7 +413,7 @@ def setup(config_options):
             _base.start(hs, config.listeners)
 
             hs.get_pusherpool().start()
-            hs.get_datastore().start_doing_background_updates()
+            hs.get_datastore().db.updates.start_doing_background_updates()
         except Exception:
             # Print the exception and bail out.
             print("Error during startup:", file=sys.stderr)
@@ -542,8 +519,10 @@ def phone_stats_home(hs, stats, stats_process=_stats_process):
     # Database version
     #
 
-    stats["database_engine"] = hs.database_engine.module.__name__
-    stats["database_server_version"] = hs.database_engine.server_version
+    # This only reports info about the *main* database.
+    stats["database_engine"] = hs.get_datastore().db.engine.module.__name__
+    stats["database_server_version"] = hs.get_datastore().db.engine.server_version
+
     logger.info("Reporting stats to %s: %s" % (hs.config.report_stats_endpoint, stats))
     try:
         yield hs.get_proxied_http_client().put_json(
diff --git a/synapse/app/user_dir.py b/synapse/app/user_dir.py
index 0fa2b50999..c01fb34a9b 100644
--- a/synapse/app/user_dir.py
+++ b/synapse/app/user_dir.py
@@ -43,6 +43,7 @@ from synapse.replication.tcp.streams.events import (
 from synapse.rest.client.v2_alpha import user_directory
 from synapse.server import HomeServer
 from synapse.storage.data_stores.main.user_directory import UserDirectoryStore
+from synapse.storage.database import Database
 from synapse.storage.engines import create_engine
 from synapse.util.caches.stream_change_cache import StreamChangeCache
 from synapse.util.httpresourcetree import create_resource_tree
@@ -60,11 +61,11 @@ class UserDirectorySlaveStore(
     UserDirectoryStore,
     BaseSlavedStore,
 ):
-    def __init__(self, db_conn, hs):
-        super(UserDirectorySlaveStore, self).__init__(db_conn, hs)
+    def __init__(self, database: Database, db_conn, hs):
+        super(UserDirectorySlaveStore, self).__init__(database, db_conn, hs)
 
         events_max = self._stream_id_gen.get_current_token()
-        curr_state_delta_prefill, min_curr_state_delta_id = self.get_cache_dict(
+        curr_state_delta_prefill, min_curr_state_delta_id = self.db.get_cache_dict(
             db_conn,
             "current_state_delta_stream",
             entity_column="room_id",
diff --git a/synapse/config/saml2_config.py b/synapse/config/saml2_config.py
index c5ea2d43a1..b91414aa35 100644
--- a/synapse/config/saml2_config.py
+++ b/synapse/config/saml2_config.py
@@ -14,17 +14,19 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
-import re
+import logging
 
 from synapse.python_dependencies import DependencyException, check_requirements
-from synapse.types import (
-    map_username_to_mxid_localpart,
-    mxid_localpart_allowed_characters,
-)
-from synapse.util.module_loader import load_python_module
+from synapse.util.module_loader import load_module, load_python_module
 
 from ._base import Config, ConfigError
 
+logger = logging.getLogger(__name__)
+
+DEFAULT_USER_MAPPING_PROVIDER = (
+    "synapse.handlers.saml_handler.DefaultSamlMappingProvider"
+)
+
 
 def _dict_merge(merge_dict, into_dict):
     """Do a deep merge of two dicts
@@ -75,15 +77,69 @@ class SAML2Config(Config):
 
         self.saml2_enabled = True
 
-        self.saml2_mxid_source_attribute = saml2_config.get(
-            "mxid_source_attribute", "uid"
-        )
-
         self.saml2_grandfathered_mxid_source_attribute = saml2_config.get(
             "grandfathered_mxid_source_attribute", "uid"
         )
 
-        saml2_config_dict = self._default_saml_config_dict()
+        # user_mapping_provider may be None if the key is present but has no value
+        ump_dict = saml2_config.get("user_mapping_provider") or {}
+
+        # Use the default user mapping provider if not set
+        ump_dict.setdefault("module", DEFAULT_USER_MAPPING_PROVIDER)
+
+        # Ensure a config is present
+        ump_dict["config"] = ump_dict.get("config") or {}
+
+        if ump_dict["module"] == DEFAULT_USER_MAPPING_PROVIDER:
+            # Load deprecated options for use by the default module
+            old_mxid_source_attribute = saml2_config.get("mxid_source_attribute")
+            if old_mxid_source_attribute:
+                logger.warning(
+                    "The config option saml2_config.mxid_source_attribute is deprecated. "
+                    "Please use saml2_config.user_mapping_provider.config"
+                    ".mxid_source_attribute instead."
+                )
+                ump_dict["config"]["mxid_source_attribute"] = old_mxid_source_attribute
+
+            old_mxid_mapping = saml2_config.get("mxid_mapping")
+            if old_mxid_mapping:
+                logger.warning(
+                    "The config option saml2_config.mxid_mapping is deprecated. Please "
+                    "use saml2_config.user_mapping_provider.config.mxid_mapping instead."
+                )
+                ump_dict["config"]["mxid_mapping"] = old_mxid_mapping
+
+        # Retrieve an instance of the module's class
+        # Pass the config dictionary to the module for processing
+        (
+            self.saml2_user_mapping_provider_class,
+            self.saml2_user_mapping_provider_config,
+        ) = load_module(ump_dict)
+
+        # Ensure loaded user mapping module has defined all necessary methods
+        # Note parse_config() is already checked during the call to load_module
+        required_methods = [
+            "get_saml_attributes",
+            "saml_response_to_user_attributes",
+        ]
+        missing_methods = [
+            method
+            for method in required_methods
+            if not hasattr(self.saml2_user_mapping_provider_class, method)
+        ]
+        if missing_methods:
+            raise ConfigError(
+                "Class specified by saml2_config."
+                "user_mapping_provider.module is missing required "
+                "methods: %s" % (", ".join(missing_methods),)
+            )
+
+        # Get the desired saml auth response attributes from the module
+        saml2_config_dict = self._default_saml_config_dict(
+            *self.saml2_user_mapping_provider_class.get_saml_attributes(
+                self.saml2_user_mapping_provider_config
+            )
+        )
         _dict_merge(
             merge_dict=saml2_config.get("sp_config", {}), into_dict=saml2_config_dict
         )
@@ -103,22 +159,27 @@ class SAML2Config(Config):
             saml2_config.get("saml_session_lifetime", "5m")
         )
 
-        mapping = saml2_config.get("mxid_mapping", "hexencode")
-        try:
-            self.saml2_mxid_mapper = MXID_MAPPER_MAP[mapping]
-        except KeyError:
-            raise ConfigError("%s is not a known mxid_mapping" % (mapping,))
-
-    def _default_saml_config_dict(self):
+    def _default_saml_config_dict(
+        self, required_attributes: set, optional_attributes: set
+    ):
+        """Generate a configuration dictionary with required and optional attributes that
+        will be needed to process new user registration
+
+        Args:
+            required_attributes: SAML auth response attributes that are
+                necessary to function
+            optional_attributes: SAML auth response attributes that can be used to add
+                additional information to Synapse user accounts, but are not required
+
+        Returns:
+            dict: A SAML configuration dictionary
+        """
         import saml2
 
         public_baseurl = self.public_baseurl
         if public_baseurl is None:
             raise ConfigError("saml2_config requires a public_baseurl to be set")
 
-        required_attributes = {"uid", self.saml2_mxid_source_attribute}
-
-        optional_attributes = {"displayName"}
         if self.saml2_grandfathered_mxid_source_attribute:
             optional_attributes.add(self.saml2_grandfathered_mxid_source_attribute)
         optional_attributes -= required_attributes
@@ -207,33 +268,58 @@ class SAML2Config(Config):
           #
           #config_path: "%(config_dir_path)s/sp_conf.py"
 
-          # the lifetime of a SAML session. This defines how long a user has to
+          # The lifetime of a SAML session. This defines how long a user has to
           # complete the authentication process, if allow_unsolicited is unset.
           # The default is 5 minutes.
           #
           #saml_session_lifetime: 5m
 
-          # The SAML attribute (after mapping via the attribute maps) to use to derive
-          # the Matrix ID from. 'uid' by default.
+          # An external module can be provided here as a custom solution for
+          # mapping attributes returned from a SAML provider onto a Matrix user.
           #
-          #mxid_source_attribute: displayName
-
-          # The mapping system to use for mapping the saml attribute onto a matrix ID.
-          # Options include:
-          #  * 'hexencode' (which maps unpermitted characters to '=xx')
-          #  * 'dotreplace' (which replaces unpermitted characters with '.').
-          # The default is 'hexencode'.
-          #
-          #mxid_mapping: dotreplace
-
-          # In previous versions of synapse, the mapping from SAML attribute to MXID was
-          # always calculated dynamically rather than stored in a table. For backwards-
-          # compatibility, we will look for user_ids matching such a pattern before
-          # creating a new account.
+          user_mapping_provider:
+            # The custom module's class. Uncomment to use a custom module.
+            #
+            #module: mapping_provider.SamlMappingProvider
+
+            # Custom configuration values for the module. Below options are
+            # intended for the built-in provider, they should be changed if
+            # using a custom module. This section will be passed as a Python
+            # dictionary to the module's `parse_config` method.
+            #
+            config:
+              # The SAML attribute (after mapping via the attribute maps) to use
+              # to derive the Matrix ID from. 'uid' by default.
+              #
+              # Note: This used to be configured by the
+              # saml2_config.mxid_source_attribute option. If that is still
+              # defined, its value will be used instead.
+              #
+              #mxid_source_attribute: displayName
+
+              # The mapping system to use for mapping the saml attribute onto a
+              # matrix ID.
+              #
+              # Options include:
+              #  * 'hexencode' (which maps unpermitted characters to '=xx')
+              #  * 'dotreplace' (which replaces unpermitted characters with
+              #     '.').
+              # The default is 'hexencode'.
+              #
+              # Note: This used to be configured by the
+              # saml2_config.mxid_mapping option. If that is still defined, its
+              # value will be used instead.
+              #
+              #mxid_mapping: dotreplace
+
+          # In previous versions of synapse, the mapping from SAML attribute to
+          # MXID was always calculated dynamically rather than stored in a
+          # table. For backwards-compatibility, we will look for user_ids
+          # matching such a pattern before creating a new account.
           #
           # This setting controls the SAML attribute which will be used for this
-          # backwards-compatibility lookup. Typically it should be 'uid', but if the
-          # attribute maps are changed, it may be necessary to change it.
+          # backwards-compatibility lookup. Typically it should be 'uid', but if
+          # the attribute maps are changed, it may be necessary to change it.
           #
           # The default is 'uid'.
           #
@@ -241,23 +327,3 @@ class SAML2Config(Config):
         """ % {
             "config_dir_path": config_dir_path
         }
-
-
-DOT_REPLACE_PATTERN = re.compile(
-    ("[^%s]" % (re.escape("".join(mxid_localpart_allowed_characters)),))
-)
-
-
-def dot_replace_for_mxid(username: str) -> str:
-    username = username.lower()
-    username = DOT_REPLACE_PATTERN.sub(".", username)
-
-    # regular mxids aren't allowed to start with an underscore either
-    username = re.sub("^_", "", username)
-    return username
-
-
-MXID_MAPPER_MAP = {
-    "hexencode": map_username_to_mxid_localpart,
-    "dotreplace": dot_replace_for_mxid,
-}
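`load_module` imports the configured `module`, calls its `parse_config` on the `config` sub-dict (which is why `parse_config` needs no explicit check above), and returns the class together with the parsed config; the two remaining required methods are then checked by hand. A minimal sketch of a custom provider that would pass those checks (the module and class names are illustrative only):

    # mapping_provider.py -- a hypothetical custom SAML mapping provider
    from typing import Tuple

    class SamlMappingProvider:
        def __init__(self, parsed_config):
            self._config = parsed_config

        @staticmethod
        def parse_config(config: dict):
            # called by load_module(); raise ConfigError on bad input
            return config

        @staticmethod
        def get_saml_attributes(config) -> Tuple[set, set]:
            # (required, optional) SAML auth response attributes
            return {"uid"}, {"displayName"}

        def saml_response_to_user_attributes(self, saml_response, failures=0) -> dict:
            uid = saml_response.ava["uid"][0]
            return {"mxid_localpart": uid + (str(failures) if failures else "")}
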
diff --git a/synapse/event_auth.py b/synapse/event_auth.py
index ec3243b27b..c940b84470 100644
--- a/synapse/event_auth.py
+++ b/synapse/event_auth.py
@@ -42,6 +42,8 @@ def check(room_version, event, auth_events, do_sig_check=True, do_size_check=Tru
     Returns:
          if the auth checks pass.
     """
+    assert isinstance(auth_events, dict)
+
     if do_size_check:
         _check_size_limits(event)
 
@@ -74,12 +76,6 @@ def check(room_version, event, auth_events, do_sig_check=True, do_size_check=Tru
             if not event.signatures.get(event_id_domain):
                 raise AuthError(403, "Event not signed by sending server")
 
-    if auth_events is None:
-        # Oh, we don't know what the state of the room was, so we
-        # are trusting that this is allowed (at least for now)
-        logger.warning("Trusting event: %s", event.event_id)
-        return
-
     if event.type == EventTypes.Create:
         sender_domain = get_domain_from_id(event.sender)
         room_id_domain = get_domain_from_id(event.room_id)
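With the `auth_events is None` escape hatch gone, `check` insists on a dict and will no longer silently trust events whose room state is unknown; callers must resolve the auth events first. The expected shape is a map keyed on `(event type, state key)` (the event variables here are hypothetical):

    auth_events = {
        ("m.room.create", ""): create_event,
        ("m.room.power_levels", ""): power_levels_event,
        ("m.room.member", event.sender): sender_membership,
    }
    check(room_version, event, auth_events, do_sig_check=False)
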
diff --git a/synapse/federation/federation_client.py b/synapse/federation/federation_client.py
index 890a201a5a..af652a7659 100644
--- a/synapse/federation/federation_client.py
+++ b/synapse/federation/federation_client.py
@@ -18,8 +18,6 @@ import copy
 import itertools
 import logging
 
-from six.moves import range
-
 from prometheus_client import Counter
 
 from twisted.internet import defer
@@ -39,7 +37,7 @@ from synapse.api.room_versions import (
 )
 from synapse.events import builder, room_version_to_event_format
 from synapse.federation.federation_base import FederationBase, event_from_pdu_json
-from synapse.logging.context import make_deferred_yieldable, run_in_background
+from synapse.logging.context import make_deferred_yieldable
 from synapse.logging.utils import log_function
 from synapse.util import unwrapFirstError
 from synapse.util.caches.expiringcache import ExpiringCache
@@ -310,162 +308,26 @@ class FederationClient(FederationBase):
         return signed_pdu
 
     @defer.inlineCallbacks
-    @log_function
-    def get_state_for_room(self, destination, room_id, event_id):
-        """Requests all of the room state at a given event from a remote homeserver.
-
-        Args:
-            destination (str): The remote homeserver to query for the state.
-            room_id (str): The id of the room we're interested in.
-            event_id (str): The id of the event we want the state at.
+    def get_room_state_ids(self, destination: str, room_id: str, event_id: str):
+        """Calls the /state_ids endpoint to fetch the state at a particular point
+        in the room, and the auth events for the given event
 
         Returns:
-            Deferred[Tuple[List[EventBase], List[EventBase]]]:
-                A list of events in the state, and a list of events in the auth chain
-                for the given event.
+            Tuple[List[str], List[str]]: A tuple of (state event_ids, auth event_ids)
         """
-        try:
-            # First we try and ask for just the IDs, as thats far quicker if
-            # we have most of the state and auth_chain already.
-            # However, this may 404 if the other side has an old synapse.
-            result = yield self.transport_layer.get_room_state_ids(
-                destination, room_id, event_id=event_id
-            )
-
-            state_event_ids = result["pdu_ids"]
-            auth_event_ids = result.get("auth_chain_ids", [])
-
-            fetched_events, failed_to_fetch = yield self.get_events_from_store_or_dest(
-                destination, room_id, set(state_event_ids + auth_event_ids)
-            )
-
-            if failed_to_fetch:
-                logger.warning(
-                    "Failed to fetch missing state/auth events for %s: %s",
-                    room_id,
-                    failed_to_fetch,
-                )
-
-            event_map = {ev.event_id: ev for ev in fetched_events}
-
-            pdus = [event_map[e_id] for e_id in state_event_ids if e_id in event_map]
-            auth_chain = [
-                event_map[e_id] for e_id in auth_event_ids if e_id in event_map
-            ]
-
-            auth_chain.sort(key=lambda e: e.depth)
-
-            return pdus, auth_chain
-        except HttpResponseException as e:
-            if e.code == 400 or e.code == 404:
-                logger.info("Failed to use get_room_state_ids API, falling back")
-            else:
-                raise e
-
-        result = yield self.transport_layer.get_room_state(
+        result = yield self.transport_layer.get_room_state_ids(
             destination, room_id, event_id=event_id
         )
 
-        room_version = yield self.store.get_room_version(room_id)
-        format_ver = room_version_to_event_format(room_version)
-
-        pdus = [
-            event_from_pdu_json(p, format_ver, outlier=True) for p in result["pdus"]
-        ]
-
-        auth_chain = [
-            event_from_pdu_json(p, format_ver, outlier=True)
-            for p in result.get("auth_chain", [])
-        ]
-
-        seen_events = yield self.store.get_events(
-            [ev.event_id for ev in itertools.chain(pdus, auth_chain)]
-        )
-
-        signed_pdus = yield self._check_sigs_and_hash_and_fetch(
-            destination,
-            [p for p in pdus if p.event_id not in seen_events],
-            outlier=True,
-            room_version=room_version,
-        )
-        signed_pdus.extend(
-            seen_events[p.event_id] for p in pdus if p.event_id in seen_events
-        )
-
-        signed_auth = yield self._check_sigs_and_hash_and_fetch(
-            destination,
-            [p for p in auth_chain if p.event_id not in seen_events],
-            outlier=True,
-            room_version=room_version,
-        )
-        signed_auth.extend(
-            seen_events[p.event_id] for p in auth_chain if p.event_id in seen_events
-        )
-
-        signed_auth.sort(key=lambda e: e.depth)
-
-        return signed_pdus, signed_auth
-
-    @defer.inlineCallbacks
-    def get_events_from_store_or_dest(self, destination, room_id, event_ids):
-        """Fetch events from a remote destination, checking if we already have them.
-
-        Args:
-            destination (str)
-            room_id (str)
-            event_ids (list)
-
-        Returns:
-            Deferred: A deferred resolving to a 2-tuple where the first is a list of
-            events and the second is a list of event ids that we failed to fetch.
-        """
-        seen_events = yield self.store.get_events(event_ids, allow_rejected=True)
-        signed_events = list(seen_events.values())
-
-        failed_to_fetch = set()
-
-        missing_events = set(event_ids)
-        for k in seen_events:
-            missing_events.discard(k)
-
-        if not missing_events:
-            return signed_events, failed_to_fetch
-
-        logger.debug(
-            "Fetching unknown state/auth events %s for room %s",
-            missing_events,
-            event_ids,
-        )
-
-        room_version = yield self.store.get_room_version(room_id)
-
-        batch_size = 20
-        missing_events = list(missing_events)
-        for i in range(0, len(missing_events), batch_size):
-            batch = set(missing_events[i : i + batch_size])
-
-            deferreds = [
-                run_in_background(
-                    self.get_pdu,
-                    destinations=[destination],
-                    event_id=e_id,
-                    room_version=room_version,
-                )
-                for e_id in batch
-            ]
-
-            res = yield make_deferred_yieldable(
-                defer.DeferredList(deferreds, consumeErrors=True)
-            )
-            for success, result in res:
-                if success and result:
-                    signed_events.append(result)
-                    batch.discard(result.event_id)
+        state_event_ids = result["pdu_ids"]
+        auth_event_ids = result.get("auth_chain_ids", [])
 
-            # We removed all events we successfully fetched from `batch`
-            failed_to_fetch.update(batch)
+        if not isinstance(state_event_ids, list) or not isinstance(
+            auth_event_ids, list
+        ):
+            raise Exception("invalid response from /state_ids")
 
-        return signed_events, failed_to_fetch
+        return state_event_ids, auth_event_ids
 
     @defer.inlineCallbacks
     @log_function
diff --git a/synapse/federation/transport/client.py b/synapse/federation/transport/client.py
index bd7fb81995..198257414b 100644
--- a/synapse/federation/transport/client.py
+++ b/synapse/federation/transport/client.py
@@ -39,30 +39,6 @@ class TransportLayerClient(object):
         self.client = hs.get_http_client()
 
     @log_function
-    def get_room_state(self, destination, room_id, event_id):
-        """ Requests all state for a given room from the given server at the
-        given event.
-
-        Args:
-            destination (str): The host name of the remote homeserver we want
-                to get the state from.
-            context (str): The name of the context we want the state of
-            event_id (str): The event we want the context at.
-
-        Returns:
-            Deferred: Results in a dict received from the remote homeserver.
-        """
-        logger.debug("get_room_state dest=%s, room=%s", destination, room_id)
-
-        path = _create_v1_path("/state/%s", room_id)
-        return self.client.get_json(
-            destination,
-            path=path,
-            args={"event_id": event_id},
-            try_trailing_slash_on_400=True,
-        )
-
-    @log_function
     def get_room_state_ids(self, destination, room_id, event_id):
         """ Requests all state for a given room from the given server at the
         given event. Returns the state's event_id's
diff --git a/synapse/handlers/e2e_keys.py b/synapse/handlers/e2e_keys.py
index 28c12753c1..57a10daefd 100644
--- a/synapse/handlers/e2e_keys.py
+++ b/synapse/handlers/e2e_keys.py
@@ -264,7 +264,6 @@ class E2eKeysHandler(object):
 
         return ret
 
-    @defer.inlineCallbacks
     def get_cross_signing_keys_from_cache(self, query, from_user_id):
         """Get cross-signing keys for users from the database
 
@@ -284,35 +283,14 @@ class E2eKeysHandler(object):
         self_signing_keys = {}
         user_signing_keys = {}
 
-        for user_id in query:
-            # XXX: consider changing the store functions to allow querying
-            # multiple users simultaneously.
-            key = yield self.store.get_e2e_cross_signing_key(
-                user_id, "master", from_user_id
-            )
-            if key:
-                master_keys[user_id] = key
-
-            key = yield self.store.get_e2e_cross_signing_key(
-                user_id, "self_signing", from_user_id
-            )
-            if key:
-                self_signing_keys[user_id] = key
-
-            # users can see other users' master and self-signing keys, but can
-            # only see their own user-signing keys
-            if from_user_id == user_id:
-                key = yield self.store.get_e2e_cross_signing_key(
-                    user_id, "user_signing", from_user_id
-                )
-                if key:
-                    user_signing_keys[user_id] = key
-
-        return {
-            "master_keys": master_keys,
-            "self_signing_keys": self_signing_keys,
-            "user_signing_keys": user_signing_keys,
-        }
+        # Currently a stub, implementation coming in https://github.com/matrix-org/synapse/pull/6486
+        return defer.succeed(
+            {
+                "master_keys": master_keys,
+                "self_signing_keys": self_signing_keys,
+                "user_signing_keys": user_signing_keys,
+            }
+        )
 
     @trace
     @defer.inlineCallbacks
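Without `@defer.inlineCallbacks` the method is no longer a generator, so it must hand back a Deferred itself to keep its callers working; `defer.succeed` wraps an immediately-available value. In miniature:

    from twisted.internet import defer

    def get_cached_value():
        # return an already-fired Deferred so callers can still yield/await it
        return defer.succeed({"master_keys": {}})
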
diff --git a/synapse/handlers/events.py b/synapse/handlers/events.py
index 45fe13c62f..ec18a42a68 100644
--- a/synapse/handlers/events.py
+++ b/synapse/handlers/events.py
@@ -16,8 +16,6 @@
 import logging
 import random
 
-from twisted.internet import defer
-
 from synapse.api.constants import EventTypes, Membership
 from synapse.api.errors import AuthError, SynapseError
 from synapse.events import EventBase
@@ -50,9 +48,8 @@ class EventStreamHandler(BaseHandler):
         self._server_notices_sender = hs.get_server_notices_sender()
         self._event_serializer = hs.get_event_client_serializer()
 
-    @defer.inlineCallbacks
     @log_function
-    def get_stream(
+    async def get_stream(
         self,
         auth_user_id,
         pagin_config,
@@ -69,17 +66,17 @@ class EventStreamHandler(BaseHandler):
         """
 
         if room_id:
-            blocked = yield self.store.is_room_blocked(room_id)
+            blocked = await self.store.is_room_blocked(room_id)
             if blocked:
                 raise SynapseError(403, "This room has been blocked on this server")
 
         # send any outstanding server notices to the user.
-        yield self._server_notices_sender.on_user_syncing(auth_user_id)
+        await self._server_notices_sender.on_user_syncing(auth_user_id)
 
         auth_user = UserID.from_string(auth_user_id)
         presence_handler = self.hs.get_presence_handler()
 
-        context = yield presence_handler.user_syncing(
+        context = await presence_handler.user_syncing(
             auth_user_id, affect_presence=affect_presence
         )
         with context:
@@ -91,7 +88,7 @@ class EventStreamHandler(BaseHandler):
                 # thundering herds on restart.
                 timeout = random.randint(int(timeout * 0.9), int(timeout * 1.1))
 
-            events, tokens = yield self.notifier.get_events_for(
+            events, tokens = await self.notifier.get_events_for(
                 auth_user,
                 pagin_config,
                 timeout,
@@ -112,14 +109,14 @@ class EventStreamHandler(BaseHandler):
                     # Send down presence.
                     if event.state_key == auth_user_id:
                         # Send down presence for everyone in the room.
-                        users = yield self.state.get_current_users_in_room(
+                        users = await self.state.get_current_users_in_room(
                             event.room_id
                         )
-                        states = yield presence_handler.get_states(users, as_event=True)
+                        states = await presence_handler.get_states(users, as_event=True)
                         to_add.extend(states)
                     else:
 
-                        ev = yield presence_handler.get_state(
+                        ev = await presence_handler.get_state(
                             UserID.from_string(event.state_key), as_event=True
                         )
                         to_add.append(ev)
@@ -128,7 +125,7 @@ class EventStreamHandler(BaseHandler):
 
             time_now = self.clock.time_msec()
 
-            chunks = yield self._event_serializer.serialize_events(
+            chunks = await self._event_serializer.serialize_events(
                 events,
                 time_now,
                 as_client_event=as_client_event,
@@ -151,8 +148,7 @@ class EventHandler(BaseHandler):
         super(EventHandler, self).__init__(hs)
         self.storage = hs.get_storage()
 
-    @defer.inlineCallbacks
-    def get_event(self, user, room_id, event_id):
+    async def get_event(self, user, room_id, event_id):
         """Retrieve a single specified event.
 
         Args:
@@ -167,15 +163,15 @@ class EventHandler(BaseHandler):
             AuthError if the user does not have the rights to inspect this
             event.
         """
-        event = yield self.store.get_event(event_id, check_room_id=room_id)
+        event = await self.store.get_event(event_id, check_room_id=room_id)
 
         if not event:
             return None
 
-        users = yield self.store.get_users_in_room(event.room_id)
+        users = await self.store.get_users_in_room(event.room_id)
         is_peeking = user.to_string() not in users
 
-        filtered = yield filter_events_for_client(
+        filtered = await filter_events_for_client(
             self.storage, user.to_string(), [event], is_peeking=is_peeking
         )
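This handler is part of a wider port from `@defer.inlineCallbacks` generators to native coroutines; Twisted Deferreds are directly awaitable, so the two styles interoperate during the transition. The conversion pattern in miniature (the `fetch` helper is hypothetical):

    from twisted.internet import defer

    @defer.inlineCallbacks
    def get_event_old(fetch, event_id):
        event = yield fetch(event_id)    # fetch returns a Deferred
        return event

    async def get_event_new(fetch, event_id):
        event = await fetch(event_id)    # Deferreds can be awaited directly
        return event

    # inlineCallbacks code can wrap a coroutine back into a Deferred:
    # d = defer.ensureDeferred(get_event_new(fetch, event_id))
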
 
diff --git a/synapse/handlers/federation.py b/synapse/handlers/federation.py
index bc26921768..c0dcf9abf8 100644
--- a/synapse/handlers/federation.py
+++ b/synapse/handlers/federation.py
@@ -64,7 +64,7 @@ from synapse.replication.http.federation import (
 from synapse.replication.http.membership import ReplicationUserJoinedLeftRoomRestServlet
 from synapse.state import StateResolutionStore, resolve_events_with_store
 from synapse.types import UserID, get_domain_from_id
-from synapse.util import unwrapFirstError
+from synapse.util import batch_iter, unwrapFirstError
 from synapse.util.async_helpers import Linearizer
 from synapse.util.distributor import user_joined_room
 from synapse.util.retryutils import NotRetryingDestination
@@ -379,11 +379,9 @@ class FederationHandler(BaseHandler):
                             (
                                 remote_state,
                                 got_auth_chain,
-                            ) = yield self.federation_client.get_state_for_room(
-                                origin, room_id, p
-                            )
+                            ) = yield self._get_state_for_room(origin, room_id, p)
 
-                            # we want the state *after* p; get_state_for_room returns the
+                            # we want the state *after* p; _get_state_for_room returns the
                             # state *before* p.
                             remote_event = yield self.federation_client.get_pdu(
                                 [origin], p, room_version, outlier=True
@@ -584,6 +582,97 @@ class FederationHandler(BaseHandler):
                         raise
 
     @defer.inlineCallbacks
+    @log_function
+    def _get_state_for_room(self, destination, room_id, event_id):
+        """Requests all of the room state at a given event from a remote homeserver.
+
+        Args:
+            destination (str): The remote homeserver to query for the state.
+            room_id (str): The id of the room we're interested in.
+            event_id (str): The id of the event we want the state at.
+
+        Returns:
+            Deferred[Tuple[List[EventBase], List[EventBase]]]:
+                A list of events in the state, and a list of events in the auth chain
+                for the given event.
+        """
+        (
+            state_event_ids,
+            auth_event_ids,
+        ) = yield self.federation_client.get_room_state_ids(
+            destination, room_id, event_id=event_id
+        )
+
+        desired_events = set(state_event_ids + auth_event_ids)
+        event_map = yield self._get_events_from_store_or_dest(
+            destination, room_id, desired_events
+        )
+
+        failed_to_fetch = desired_events - event_map.keys()
+        if failed_to_fetch:
+            logger.warning(
+                "Failed to fetch missing state/auth events for %s: %s",
+                room_id,
+                failed_to_fetch,
+            )
+
+        pdus = [event_map[e_id] for e_id in state_event_ids if e_id in event_map]
+        auth_chain = [event_map[e_id] for e_id in auth_event_ids if e_id in event_map]
+
+        auth_chain.sort(key=lambda e: e.depth)
+
+        return pdus, auth_chain
+
+    @defer.inlineCallbacks
+    def _get_events_from_store_or_dest(self, destination, room_id, event_ids):
+        """Fetch events from a remote destination, checking if we already have them.
+
+        Args:
+            destination (str)
+            room_id (str)
+            event_ids (Iterable[str])
+
+        Returns:
+            Deferred[dict[str, EventBase]]: A deferred resolving to a map
+            from event_id to event
+        """
+        fetched_events = yield self.store.get_events(event_ids, allow_rejected=True)
+
+        missing_events = set(event_ids) - fetched_events.keys()
+
+        if not missing_events:
+            return fetched_events
+
+        logger.debug(
+            "Fetching unknown state/auth events %s for room %s",
+            missing_events,
+            room_id,
+        )
+
+        room_version = yield self.store.get_room_version(room_id)
+
+        # XXX 20 requests at once? really?
+        for batch in batch_iter(missing_events, 20):
+            deferreds = [
+                run_in_background(
+                    self.federation_client.get_pdu,
+                    destinations=[destination],
+                    event_id=e_id,
+                    room_version=room_version,
+                )
+                for e_id in batch
+            ]
+
+            res = yield make_deferred_yieldable(
+                defer.DeferredList(deferreds, consumeErrors=True)
+            )
+            for success, result in res:
+                if success and result:
+                    fetched_events[result.event_id] = result
+
+        return fetched_events
+
+    @defer.inlineCallbacks
     def _process_received_pdu(self, origin, event, state, auth_chain):
         """ Called when we have a new pdu. We need to do auth checks and put it
         through the StateHandler.
@@ -723,7 +812,7 @@ class FederationHandler(BaseHandler):
         state_events = {}
         events_to_state = {}
         for e_id in edges:
-            state, auth = yield self.federation_client.get_state_for_room(
+            state, auth = yield self._get_state_for_room(
                 destination=dest, room_id=room_id, event_id=e_id
             )
             auth_events.update({a.event_id: a for a in auth})
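`_get_events_from_store_or_dest` now leans on `batch_iter` from `synapse.util` to fetch the missing events twenty at a time. Roughly, `batch_iter` chunks any iterable into fixed-size tuples; a sketch of the behaviour (not necessarily the exact implementation):

    from itertools import islice

    def batch_iter(iterable, size):
        # yield successive tuples of at most `size` items until exhausted
        sourceiter = iter(iterable)
        return iter(lambda: tuple(islice(sourceiter, size)), ())

    # list(batch_iter(range(5), 2)) == [(0, 1), (2, 3), (4,)]
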
diff --git a/synapse/handlers/initial_sync.py b/synapse/handlers/initial_sync.py
index 81dce96f4b..73c110a92b 100644
--- a/synapse/handlers/initial_sync.py
+++ b/synapse/handlers/initial_sync.py
@@ -26,7 +26,7 @@ from synapse.streams.config import PaginationConfig
 from synapse.types import StreamToken, UserID
 from synapse.util import unwrapFirstError
 from synapse.util.async_helpers import concurrently_execute
-from synapse.util.caches.snapshot_cache import SnapshotCache
+from synapse.util.caches.response_cache import ResponseCache
 from synapse.visibility import filter_events_for_client
 
 from ._base import BaseHandler
@@ -41,7 +41,7 @@ class InitialSyncHandler(BaseHandler):
         self.state = hs.get_state_handler()
         self.clock = hs.get_clock()
         self.validator = EventValidator()
-        self.snapshot_cache = SnapshotCache()
+        self.snapshot_cache = ResponseCache(hs, "initial_sync_cache")
         self._event_serializer = hs.get_event_client_serializer()
         self.storage = hs.get_storage()
         self.state_store = self.storage.state
@@ -79,17 +79,14 @@ class InitialSyncHandler(BaseHandler):
             as_client_event,
             include_archived,
         )
-        now_ms = self.clock.time_msec()
-        result = self.snapshot_cache.get(now_ms, key)
-        if result is not None:
-            return result
 
-        return self.snapshot_cache.set(
-            now_ms,
+        return self.snapshot_cache.wrap(
             key,
-            self._snapshot_all_rooms(
-                user_id, pagin_config, as_client_event, include_archived
-            ),
+            self._snapshot_all_rooms,
+            user_id,
+            pagin_config,
+            as_client_event,
+            include_archived,
         )
 
     @defer.inlineCallbacks
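`ResponseCache.wrap` replaces the hand-rolled get-then-set dance: concurrent callers with the same key share a single in-flight computation rather than each starting their own. The idea in miniature (a sketch only; the real `ResponseCache` hands each caller its own observer and handles logging contexts):

    from twisted.internet import defer

    class MiniResponseCache:
        def __init__(self):
            self._pending = {}

        def wrap(self, key, callback, *args, **kwargs):
            if key not in self._pending:
                d = defer.maybeDeferred(callback, *args, **kwargs)
                self._pending[key] = d
                # drop the entry once the result has been delivered
                d.addBoth(self._evict, key)
            return self._pending[key]

        def _evict(self, result, key):
            self._pending.pop(key, None)
            return result
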
diff --git a/synapse/handlers/message.py b/synapse/handlers/message.py
index 4f53a5f5dc..54fa216d83 100644
--- a/synapse/handlers/message.py
+++ b/synapse/handlers/message.py
@@ -363,6 +363,8 @@ class EventCreationHandler(object):
         self.config = hs.config
         self.require_membership_for_aliases = hs.config.require_membership_for_aliases
 
+        self.room_invite_state_types = self.hs.config.room_invite_state_types
+
         self.send_event_to_master = ReplicationSendEventRestServlet.make_client(hs)
 
         # This is only used to get at ratelimit function, and maybe_kick_guest_users
@@ -916,7 +918,7 @@ class EventCreationHandler(object):
                 state_to_include_ids = [
                     e_id
                     for k, e_id in iteritems(current_state_ids)
-                    if k[0] in self.hs.config.room_invite_state_types
+                    if k[0] in self.room_invite_state_types
                     or k == (EventTypes.Member, event.sender)
                 ]
 
diff --git a/synapse/handlers/saml_handler.py b/synapse/handlers/saml_handler.py
index cc9e6b9bd0..0082f85c26 100644
--- a/synapse/handlers/saml_handler.py
+++ b/synapse/handlers/saml_handler.py
@@ -13,20 +13,36 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 import logging
+import re
+from typing import Tuple
 
 import attr
 import saml2
+import saml2.response
 from saml2.client import Saml2Client
 
 from synapse.api.errors import SynapseError
+from synapse.config import ConfigError
 from synapse.http.servlet import parse_string
 from synapse.rest.client.v1.login import SSOAuthHandler
-from synapse.types import UserID, map_username_to_mxid_localpart
+from synapse.types import (
+    UserID,
+    map_username_to_mxid_localpart,
+    mxid_localpart_allowed_characters,
+)
 from synapse.util.async_helpers import Linearizer
 
 logger = logging.getLogger(__name__)
 
 
+@attr.s
+class Saml2SessionData:
+    """Data we track about SAML2 sessions"""
+
+    # time the session was created, in milliseconds
+    creation_time = attr.ib()
+
+
 class SamlHandler:
     def __init__(self, hs):
         self._saml_client = Saml2Client(hs.config.saml2_sp_config)
@@ -37,11 +53,14 @@ class SamlHandler:
         self._datastore = hs.get_datastore()
         self._hostname = hs.hostname
         self._saml2_session_lifetime = hs.config.saml2_session_lifetime
-        self._mxid_source_attribute = hs.config.saml2_mxid_source_attribute
         self._grandfathered_mxid_source_attribute = (
             hs.config.saml2_grandfathered_mxid_source_attribute
         )
-        self._mxid_mapper = hs.config.saml2_mxid_mapper
+
+        # plugin to do custom mapping from SAML response to mxid
+        self._user_mapping_provider = hs.config.saml2_user_mapping_provider_class(
+            hs.config.saml2_user_mapping_provider_config
+        )
 
         # identifier for the external_ids table
         self._auth_provider_id = "saml"
@@ -118,22 +137,10 @@ class SamlHandler:
             remote_user_id = saml2_auth.ava["uid"][0]
         except KeyError:
             logger.warning("SAML2 response lacks a 'uid' attestation")
-            raise SynapseError(400, "uid not in SAML2 response")
-
-        try:
-            mxid_source = saml2_auth.ava[self._mxid_source_attribute][0]
-        except KeyError:
-            logger.warning(
-                "SAML2 response lacks a '%s' attestation", self._mxid_source_attribute
-            )
-            raise SynapseError(
-                400, "%s not in SAML2 response" % (self._mxid_source_attribute,)
-            )
+            raise SynapseError(400, "'uid' not in SAML2 response")
 
         self._outstanding_requests_dict.pop(saml2_auth.in_response_to, None)
 
-        displayName = saml2_auth.ava.get("displayName", [None])[0]
-
         with (await self._mapping_lock.queue(self._auth_provider_id)):
             # first of all, check if we already have a mapping for this user
             logger.info(
@@ -173,22 +180,46 @@ class SamlHandler:
                     )
                     return registered_user_id
 
-            # figure out a new mxid for this user
-            base_mxid_localpart = self._mxid_mapper(mxid_source)
+            # Map saml response to user attributes using the configured mapping provider
+            for i in range(1000):
+                attribute_dict = self._user_mapping_provider.saml_response_to_user_attributes(
+                    saml2_auth, i
+                )
+
+                logger.debug(
+                    "Retrieved SAML attributes from user mapping provider: %s "
+                    "(attempt %d)",
+                    attribute_dict,
+                    i,
+                )
+
+                localpart = attribute_dict.get("mxid_localpart")
+                if not localpart:
+                    logger.error(
+                        "SAML mapping provider plugin did not return a "
+                        "mxid_localpart object"
+                    )
+                    raise SynapseError(500, "Error parsing SAML2 response")
 
-            suffix = 0
-            while True:
-                localpart = base_mxid_localpart + (str(suffix) if suffix else "")
+                displayname = attribute_dict.get("displayname")
+
+                # Check if this mxid already exists
                 if not await self._datastore.get_users_by_id_case_insensitive(
                     UserID(localpart, self._hostname).to_string()
                 ):
+                    # This mxid is free
                     break
-                suffix += 1
-            logger.info("Allocating mxid for new user with localpart %s", localpart)
+            else:
+                # Unable to generate a usable mxid in 1000 iterations;
+                # give up and return an error to the user
+                raise SynapseError(
+                    500, "Unable to generate a Matrix ID from the SAML response"
+                )
 
             registered_user_id = await self._registration_handler.register_user(
-                localpart=localpart, default_display_name=displayName
+                localpart=localpart, default_display_name=displayname
             )
+
             await self._datastore.record_user_external_id(
                 self._auth_provider_id, remote_user_id, registered_user_id
             )
@@ -205,9 +236,120 @@ class SamlHandler:
             del self._outstanding_requests_dict[reqid]
 
 
+DOT_REPLACE_PATTERN = re.compile(
+    ("[^%s]" % (re.escape("".join(mxid_localpart_allowed_characters)),))
+)
+
+
+def dot_replace_for_mxid(username: str) -> str:
+    username = username.lower()
+    username = DOT_REPLACE_PATTERN.sub(".", username)
+
+    # regular mxids aren't allowed to start with an underscore either
+    username = re.sub("^_", "", username)
+    return username
+
+
+MXID_MAPPER_MAP = {
+    "hexencode": map_username_to_mxid_localpart,
+    "dotreplace": dot_replace_for_mxid,
+}
+
+
 @attr.s
-class Saml2SessionData:
-    """Data we track about SAML2 sessions"""
+class SamlConfig(object):
+    mxid_source_attribute = attr.ib()
+    mxid_mapper = attr.ib()
 
-    # time the session was created, in milliseconds
-    creation_time = attr.ib()
+
+class DefaultSamlMappingProvider(object):
+    __version__ = "0.0.1"
+
+    def __init__(self, parsed_config: SamlConfig):
+        """The default SAML user mapping provider
+
+        Args:
+            parsed_config: Module configuration
+        """
+        self._mxid_source_attribute = parsed_config.mxid_source_attribute
+        self._mxid_mapper = parsed_config.mxid_mapper
+
+    def saml_response_to_user_attributes(
+        self, saml_response: saml2.response.AuthnResponse, failures: int = 0,
+    ) -> dict:
+        """Maps some text from a SAML response to attributes of a new user
+
+        Args:
+            saml_response: A SAML auth response object
+
+            failures: How many times a call to this function with this
+                saml_response has resulted in a failure
+
+        Returns:
+            dict: A dict containing new user attributes. Possible keys:
+                * mxid_localpart (str): Required. The localpart of the user's mxid
+                * displayname (str): The displayname of the user
+        """
+        try:
+            mxid_source = saml_response.ava[self._mxid_source_attribute][0]
+        except KeyError:
+            logger.warning(
+                "SAML2 response lacks a '%s' attestation", self._mxid_source_attribute,
+            )
+            raise SynapseError(
+                400, "%s not in SAML2 response" % (self._mxid_source_attribute,)
+            )
+
+        # Use the configured mapper for this mxid_source
+        base_mxid_localpart = self._mxid_mapper(mxid_source)
+
+        # Append suffix integer if last call to this function failed to produce
+        # a usable mxid
+        localpart = base_mxid_localpart + (str(failures) if failures else "")
+
+        # Retrieve the display name from the SAML response
+        # If displayname is None, the mxid_localpart will be used instead
+        displayname = saml_response.ava.get("displayName", [None])[0]
+
+        return {
+            "mxid_localpart": localpart,
+            "displayname": displayname,
+        }
+
+    @staticmethod
+    def parse_config(config: dict) -> SamlConfig:
+        """Parse the dict provided by the homeserver's config
+        Args:
+            config: A dictionary containing configuration options for this provider
+        Returns:
+            SamlConfig: A custom config object for this module
+        """
+        # Parse config options and use defaults where necessary
+        mxid_source_attribute = config.get("mxid_source_attribute", "uid")
+        mapping_type = config.get("mxid_mapping", "hexencode")
+
+        # Retrieve the associating mapping function
+        try:
+            mxid_mapper = MXID_MAPPER_MAP[mapping_type]
+        except KeyError:
+            raise ConfigError(
+                "saml2_config.user_mapping_provider.config: '%s' is not a valid "
+                "mxid_mapping value" % (mapping_type,)
+            )
+
+        return SamlConfig(mxid_source_attribute, mxid_mapper)
+
+    @staticmethod
+    def get_saml_attributes(config: SamlConfig) -> Tuple[set, set]:
+        """Returns the required attributes of a SAML
+
+        Args:
+            config: A SamlConfig object containing configuration params for this provider
+
+        Returns:
+            tuple[set,set]: The first set contains the SAML auth response
+                attributes that are required for the module to function; the
+                second contains those attributes which can be used if
+                available, but are not necessary
+        """
+        return {"uid", config.mxid_source_attribute}, {"displayName"}
diff --git a/synapse/handlers/sync.py b/synapse/handlers/sync.py
index b536d410e5..2d3b8ba73c 100644
--- a/synapse/handlers/sync.py
+++ b/synapse/handlers/sync.py
@@ -22,8 +22,6 @@ from six import iteritems, itervalues
 
 from prometheus_client import Counter
 
-from twisted.internet import defer
-
 from synapse.api.constants import EventTypes, Membership
 from synapse.logging.context import LoggingContext
 from synapse.push.clientformat import format_push_rules_for_user
@@ -241,8 +239,7 @@ class SyncHandler(object):
             expiry_ms=LAZY_LOADED_MEMBERS_CACHE_MAX_AGE,
         )
 
-    @defer.inlineCallbacks
-    def wait_for_sync_for_user(
+    async def wait_for_sync_for_user(
         self, sync_config, since_token=None, timeout=0, full_state=False
     ):
         """Get the sync for a client if we have new data for it now. Otherwise
@@ -255,9 +252,9 @@ class SyncHandler(object):
         # not been exceeded (if not part of the group by this point, almost certain
         # auth_blocking will occur)
         user_id = sync_config.user.to_string()
-        yield self.auth.check_auth_blocking(user_id)
+        await self.auth.check_auth_blocking(user_id)
 
-        res = yield self.response_cache.wrap(
+        res = await self.response_cache.wrap(
             sync_config.request_key,
             self._wait_for_sync_for_user,
             sync_config,
@@ -267,8 +264,9 @@ class SyncHandler(object):
         )
         return res
 
-    @defer.inlineCallbacks
-    def _wait_for_sync_for_user(self, sync_config, since_token, timeout, full_state):
+    async def _wait_for_sync_for_user(
+        self, sync_config, since_token, timeout, full_state
+    ):
         if since_token is None:
             sync_type = "initial_sync"
         elif full_state:
@@ -283,7 +281,7 @@ class SyncHandler(object):
         if timeout == 0 or since_token is None or full_state:
             # we are going to return immediately, so don't bother calling
             # notifier.wait_for_events.
-            result = yield self.current_sync_for_user(
+            result = await self.current_sync_for_user(
                 sync_config, since_token, full_state=full_state
             )
         else:
@@ -291,7 +289,7 @@ class SyncHandler(object):
             def current_sync_callback(before_token, after_token):
                 return self.current_sync_for_user(sync_config, since_token)
 
-            result = yield self.notifier.wait_for_events(
+            result = await self.notifier.wait_for_events(
                 sync_config.user.to_string(),
                 timeout,
                 current_sync_callback,
@@ -314,15 +312,13 @@ class SyncHandler(object):
         """
         return self.generate_sync_result(sync_config, since_token, full_state)
 
-    @defer.inlineCallbacks
-    def push_rules_for_user(self, user):
+    async def push_rules_for_user(self, user):
         user_id = user.to_string()
-        rules = yield self.store.get_push_rules_for_user(user_id)
+        rules = await self.store.get_push_rules_for_user(user_id)
         rules = format_push_rules_for_user(user, rules)
         return rules
 
-    @defer.inlineCallbacks
-    def ephemeral_by_room(self, sync_result_builder, now_token, since_token=None):
+    async def ephemeral_by_room(self, sync_result_builder, now_token, since_token=None):
         """Get the ephemeral events for each room the user is in
         Args:
             sync_result_builder(SyncResultBuilder)
@@ -343,7 +339,7 @@ class SyncHandler(object):
             room_ids = sync_result_builder.joined_room_ids
 
             typing_source = self.event_sources.sources["typing"]
-            typing, typing_key = yield typing_source.get_new_events(
+            typing, typing_key = await typing_source.get_new_events(
                 user=sync_config.user,
                 from_key=typing_key,
                 limit=sync_config.filter_collection.ephemeral_limit(),
@@ -365,7 +361,7 @@ class SyncHandler(object):
             receipt_key = since_token.receipt_key if since_token else "0"
 
             receipt_source = self.event_sources.sources["receipt"]
-            receipts, receipt_key = yield receipt_source.get_new_events(
+            receipts, receipt_key = await receipt_source.get_new_events(
                 user=sync_config.user,
                 from_key=receipt_key,
                 limit=sync_config.filter_collection.ephemeral_limit(),
@@ -382,8 +378,7 @@ class SyncHandler(object):
 
         return now_token, ephemeral_by_room
 
-    @defer.inlineCallbacks
-    def _load_filtered_recents(
+    async def _load_filtered_recents(
         self,
         room_id,
         sync_config,
@@ -415,10 +410,10 @@ class SyncHandler(object):
                 # ensure that we always include current state in the timeline
                 current_state_ids = frozenset()
                 if any(e.is_state() for e in recents):
-                    current_state_ids = yield self.state.get_current_state_ids(room_id)
+                    current_state_ids = await self.state.get_current_state_ids(room_id)
                     current_state_ids = frozenset(itervalues(current_state_ids))
 
-                recents = yield filter_events_for_client(
+                recents = await filter_events_for_client(
                     self.storage,
                     sync_config.user.to_string(),
                     recents,
@@ -449,14 +444,14 @@ class SyncHandler(object):
                 # Otherwise, we want to return the last N events in the room
                 # in topological ordering.
                 if since_key:
-                    events, end_key = yield self.store.get_room_events_stream_for_room(
+                    events, end_key = await self.store.get_room_events_stream_for_room(
                         room_id,
                         limit=load_limit + 1,
                         from_key=since_key,
                         to_key=end_key,
                     )
                 else:
-                    events, end_key = yield self.store.get_recent_events_for_room(
+                    events, end_key = await self.store.get_recent_events_for_room(
                         room_id, limit=load_limit + 1, end_token=end_key
                     )
                 loaded_recents = sync_config.filter_collection.filter_room_timeline(
@@ -468,10 +463,10 @@ class SyncHandler(object):
                 # ensure that we always include current state in the timeline
                 current_state_ids = frozenset()
                 if any(e.is_state() for e in loaded_recents):
-                    current_state_ids = yield self.state.get_current_state_ids(room_id)
+                    current_state_ids = await self.state.get_current_state_ids(room_id)
                     current_state_ids = frozenset(itervalues(current_state_ids))
 
-                loaded_recents = yield filter_events_for_client(
+                loaded_recents = await filter_events_for_client(
                     self.storage,
                     sync_config.user.to_string(),
                     loaded_recents,
@@ -498,8 +493,7 @@ class SyncHandler(object):
             limited=limited or newly_joined_room,
         )
 
-    @defer.inlineCallbacks
-    def get_state_after_event(self, event, state_filter=StateFilter.all()):
+    async def get_state_after_event(self, event, state_filter=StateFilter.all()):
         """
         Get the room state after the given event
 
@@ -511,7 +505,7 @@ class SyncHandler(object):
         Returns:
             A map from ((type, state_key)->Event)
         """
-        state_ids = yield self.state_store.get_state_ids_for_event(
+        state_ids = await self.state_store.get_state_ids_for_event(
             event.event_id, state_filter=state_filter
         )
         if event.is_state():
@@ -519,8 +513,9 @@ class SyncHandler(object):
             state_ids[(event.type, event.state_key)] = event.event_id
         return state_ids
 
-    @defer.inlineCallbacks
-    def get_state_at(self, room_id, stream_position, state_filter=StateFilter.all()):
+    async def get_state_at(
+        self, room_id, stream_position, state_filter=StateFilter.all()
+    ):
         """ Get the room state at a particular stream position
 
         Args:
@@ -536,13 +531,13 @@ class SyncHandler(object):
         # get_recent_events_for_room operates by topo ordering. This therefore
         # does not reliably give you the state at the given stream position.
         # (https://github.com/matrix-org/synapse/issues/3305)
-        last_events, _ = yield self.store.get_recent_events_for_room(
+        last_events, _ = await self.store.get_recent_events_for_room(
             room_id, end_token=stream_position.room_key, limit=1
         )
 
         if last_events:
             last_event = last_events[-1]
-            state = yield self.get_state_after_event(
+            state = await self.get_state_after_event(
                 last_event, state_filter=state_filter
             )
 
@@ -551,8 +546,7 @@ class SyncHandler(object):
             state = {}
         return state
 
-    @defer.inlineCallbacks
-    def compute_summary(self, room_id, sync_config, batch, state, now_token):
+    async def compute_summary(self, room_id, sync_config, batch, state, now_token):
         """ Works out a room summary block for this room, summarising the number
         of joined members in the room, and providing the 'hero' members if the
         room has no name so clients can consistently name rooms.  Also adds
@@ -574,7 +568,7 @@ class SyncHandler(object):
         # FIXME: we could/should get this from room_stats when matthew/stats lands
 
         # FIXME: this promulgates https://github.com/matrix-org/synapse/issues/3305
-        last_events, _ = yield self.store.get_recent_event_ids_for_room(
+        last_events, _ = await self.store.get_recent_event_ids_for_room(
             room_id, end_token=now_token.room_key, limit=1
         )
 
@@ -582,7 +576,7 @@ class SyncHandler(object):
             return None
 
         last_event = last_events[-1]
-        state_ids = yield self.state_store.get_state_ids_for_event(
+        state_ids = await self.state_store.get_state_ids_for_event(
             last_event.event_id,
             state_filter=StateFilter.from_types(
                 [(EventTypes.Name, ""), (EventTypes.CanonicalAlias, "")]
@@ -590,7 +584,7 @@ class SyncHandler(object):
         )
 
         # this is heavily cached, thus: fast.
-        details = yield self.store.get_room_summary(room_id)
+        details = await self.store.get_room_summary(room_id)
 
         name_id = state_ids.get((EventTypes.Name, ""))
         canonical_alias_id = state_ids.get((EventTypes.CanonicalAlias, ""))
@@ -608,12 +602,12 @@ class SyncHandler(object):
         # calculating heroes. Empty strings are falsey, so we check
         # for the "name" value and default to an empty string.
         if name_id:
-            name = yield self.store.get_event(name_id, allow_none=True)
+            name = await self.store.get_event(name_id, allow_none=True)
             if name and name.content.get("name"):
                 return summary
 
         if canonical_alias_id:
-            canonical_alias = yield self.store.get_event(
+            canonical_alias = await self.store.get_event(
                 canonical_alias_id, allow_none=True
             )
             if canonical_alias and canonical_alias.content.get("alias"):
@@ -678,7 +672,7 @@ class SyncHandler(object):
             )
         ]
 
-        missing_hero_state = yield self.store.get_events(missing_hero_event_ids)
+        missing_hero_state = await self.store.get_events(missing_hero_event_ids)
         missing_hero_state = missing_hero_state.values()
 
         for s in missing_hero_state:
@@ -697,8 +691,7 @@ class SyncHandler(object):
             logger.debug("found LruCache for %r", cache_key)
         return cache
 
-    @defer.inlineCallbacks
-    def compute_state_delta(
+    async def compute_state_delta(
         self, room_id, batch, sync_config, since_token, now_token, full_state
     ):
         """ Works out the difference in state between the start of the timeline
@@ -759,16 +752,16 @@ class SyncHandler(object):
 
             if full_state:
                 if batch:
-                    current_state_ids = yield self.state_store.get_state_ids_for_event(
+                    current_state_ids = await self.state_store.get_state_ids_for_event(
                         batch.events[-1].event_id, state_filter=state_filter
                     )
 
-                    state_ids = yield self.state_store.get_state_ids_for_event(
+                    state_ids = await self.state_store.get_state_ids_for_event(
                         batch.events[0].event_id, state_filter=state_filter
                     )
 
                 else:
-                    current_state_ids = yield self.get_state_at(
+                    current_state_ids = await self.get_state_at(
                         room_id, stream_position=now_token, state_filter=state_filter
                     )
 
@@ -783,13 +776,13 @@ class SyncHandler(object):
                 )
             elif batch.limited:
                 if batch:
-                    state_at_timeline_start = yield self.state_store.get_state_ids_for_event(
+                    state_at_timeline_start = await self.state_store.get_state_ids_for_event(
                         batch.events[0].event_id, state_filter=state_filter
                     )
                 else:
                     # We can get here if the user has ignored the senders of all
                     # the recent events.
-                    state_at_timeline_start = yield self.get_state_at(
+                    state_at_timeline_start = await self.get_state_at(
                         room_id, stream_position=now_token, state_filter=state_filter
                     )
 
@@ -807,19 +800,19 @@ class SyncHandler(object):
                 # about them).
                 state_filter = StateFilter.all()
 
-                state_at_previous_sync = yield self.get_state_at(
+                state_at_previous_sync = await self.get_state_at(
                     room_id, stream_position=since_token, state_filter=state_filter
                 )
 
                 if batch:
-                    current_state_ids = yield self.state_store.get_state_ids_for_event(
+                    current_state_ids = await self.state_store.get_state_ids_for_event(
                         batch.events[-1].event_id, state_filter=state_filter
                     )
                 else:
                     # It's not clear how we get here, but empirically we do
                     # (#5407). Logging has been added elsewhere to try and
                     # figure out where this state comes from.
-                    current_state_ids = yield self.get_state_at(
+                    current_state_ids = await self.get_state_at(
                         room_id, stream_position=now_token, state_filter=state_filter
                     )
 
@@ -843,7 +836,7 @@ class SyncHandler(object):
                         # So we fish out all the member events corresponding to the
                         # timeline here, and then dedupe any redundant ones below.
 
-                        state_ids = yield self.state_store.get_state_ids_for_event(
+                        state_ids = await self.state_store.get_state_ids_for_event(
                             batch.events[0].event_id,
                             # we only want members!
                             state_filter=StateFilter.from_types(
@@ -883,7 +876,7 @@ class SyncHandler(object):
 
         state = {}
         if state_ids:
-            state = yield self.store.get_events(list(state_ids.values()))
+            state = await self.store.get_events(list(state_ids.values()))
 
         return {
             (e.type, e.state_key): e
@@ -892,10 +885,9 @@ class SyncHandler(object):
             )
         }
 
-    @defer.inlineCallbacks
-    def unread_notifs_for_room_id(self, room_id, sync_config):
+    async def unread_notifs_for_room_id(self, room_id, sync_config):
         with Measure(self.clock, "unread_notifs_for_room_id"):
-            last_unread_event_id = yield self.store.get_last_receipt_event_id_for_user(
+            last_unread_event_id = await self.store.get_last_receipt_event_id_for_user(
                 user_id=sync_config.user.to_string(),
                 room_id=room_id,
                 receipt_type="m.read",
@@ -903,7 +895,7 @@ class SyncHandler(object):
 
             notifs = []
             if last_unread_event_id:
-                notifs = yield self.store.get_unread_event_push_actions_by_room_for_user(
+                notifs = await self.store.get_unread_event_push_actions_by_room_for_user(
                     room_id, sync_config.user.to_string(), last_unread_event_id
                 )
                 return notifs
@@ -912,8 +904,9 @@ class SyncHandler(object):
         # count is whatever it was last time.
         return None
 
-    @defer.inlineCallbacks
-    def generate_sync_result(self, sync_config, since_token=None, full_state=False):
+    async def generate_sync_result(
+        self, sync_config, since_token=None, full_state=False
+    ):
         """Generates a sync result.
 
         Args:
@@ -928,7 +921,7 @@ class SyncHandler(object):
         # this is due to some of the underlying streams not supporting the ability
         # to query up to a given point.
         # Always use the `now_token` in `SyncResultBuilder`
-        now_token = yield self.event_sources.get_current_token()
+        now_token = await self.event_sources.get_current_token()
 
         logger.info(
             "Calculating sync response for %r between %s and %s",
@@ -944,10 +937,9 @@ class SyncHandler(object):
             # See https://github.com/matrix-org/matrix-doc/issues/1144
             raise NotImplementedError()
         else:
-            joined_room_ids = yield self.get_rooms_for_user_at(
+            joined_room_ids = await self.get_rooms_for_user_at(
                 user_id, now_token.room_stream_id
             )
-
         sync_result_builder = SyncResultBuilder(
             sync_config,
             full_state,
@@ -956,11 +948,11 @@ class SyncHandler(object):
             joined_room_ids=joined_room_ids,
         )
 
-        account_data_by_room = yield self._generate_sync_entry_for_account_data(
+        account_data_by_room = await self._generate_sync_entry_for_account_data(
             sync_result_builder
         )
 
-        res = yield self._generate_sync_entry_for_rooms(
+        res = await self._generate_sync_entry_for_rooms(
             sync_result_builder, account_data_by_room
         )
         newly_joined_rooms, newly_joined_or_invited_users, _, _ = res
@@ -970,13 +962,13 @@ class SyncHandler(object):
             since_token is None and sync_config.filter_collection.blocks_all_presence()
         )
         if self.hs_config.use_presence and not block_all_presence_data:
-            yield self._generate_sync_entry_for_presence(
+            await self._generate_sync_entry_for_presence(
                 sync_result_builder, newly_joined_rooms, newly_joined_or_invited_users
             )
 
-        yield self._generate_sync_entry_for_to_device(sync_result_builder)
+        await self._generate_sync_entry_for_to_device(sync_result_builder)
 
-        device_lists = yield self._generate_sync_entry_for_device_list(
+        device_lists = await self._generate_sync_entry_for_device_list(
             sync_result_builder,
             newly_joined_rooms=newly_joined_rooms,
             newly_joined_or_invited_users=newly_joined_or_invited_users,
@@ -987,11 +979,11 @@ class SyncHandler(object):
         device_id = sync_config.device_id
         one_time_key_counts = {}
         if device_id:
-            one_time_key_counts = yield self.store.count_e2e_one_time_keys(
+            one_time_key_counts = await self.store.count_e2e_one_time_keys(
                 user_id, device_id
             )
 
-        yield self._generate_sync_entry_for_groups(sync_result_builder)
+        await self._generate_sync_entry_for_groups(sync_result_builder)
 
         # debug for https://github.com/matrix-org/synapse/issues/4422
         for joined_room in sync_result_builder.joined:
@@ -1015,18 +1007,17 @@ class SyncHandler(object):
         )
 
     @measure_func("_generate_sync_entry_for_groups")
-    @defer.inlineCallbacks
-    def _generate_sync_entry_for_groups(self, sync_result_builder):
+    async def _generate_sync_entry_for_groups(self, sync_result_builder):
         user_id = sync_result_builder.sync_config.user.to_string()
         since_token = sync_result_builder.since_token
         now_token = sync_result_builder.now_token
 
         if since_token and since_token.groups_key:
-            results = yield self.store.get_groups_changes_for_user(
+            results = await self.store.get_groups_changes_for_user(
                 user_id, since_token.groups_key, now_token.groups_key
             )
         else:
-            results = yield self.store.get_all_groups_for_user(
+            results = await self.store.get_all_groups_for_user(
                 user_id, now_token.groups_key
             )
 
@@ -1059,8 +1050,7 @@ class SyncHandler(object):
         )
 
     @measure_func("_generate_sync_entry_for_device_list")
-    @defer.inlineCallbacks
-    def _generate_sync_entry_for_device_list(
+    async def _generate_sync_entry_for_device_list(
         self,
         sync_result_builder,
         newly_joined_rooms,
@@ -1108,32 +1098,32 @@ class SyncHandler(object):
             # room with by looking at all users that have left a room plus users
             # that were in a room we've left.
 
-            users_who_share_room = yield self.store.get_users_who_share_room_with_user(
+            users_who_share_room = await self.store.get_users_who_share_room_with_user(
                 user_id
             )
 
             # Step 1a, check for changes in devices of users we share a room with
-            users_that_have_changed = yield self.store.get_users_whose_devices_changed(
+            users_that_have_changed = await self.store.get_users_whose_devices_changed(
                 since_token.device_list_key, users_who_share_room
             )
 
             # Step 1b, check for newly joined rooms
             for room_id in newly_joined_rooms:
-                joined_users = yield self.state.get_current_users_in_room(room_id)
+                joined_users = await self.state.get_current_users_in_room(room_id)
                 newly_joined_or_invited_users.update(joined_users)
 
             # TODO: Check that these users are actually new, i.e. either they
             # weren't in the previous sync *or* they left and rejoined.
             users_that_have_changed.update(newly_joined_or_invited_users)
 
-            user_signatures_changed = yield self.store.get_users_whose_signatures_changed(
+            user_signatures_changed = await self.store.get_users_whose_signatures_changed(
                 user_id, since_token.device_list_key
             )
             users_that_have_changed.update(user_signatures_changed)
 
             # Now find users that we no longer track
             for room_id in newly_left_rooms:
-                left_users = yield self.state.get_current_users_in_room(room_id)
+                left_users = await self.state.get_current_users_in_room(room_id)
                 newly_left_users.update(left_users)
 
             # Remove any users that we still share a room with.
@@ -1143,8 +1133,7 @@ class SyncHandler(object):
         else:
             return DeviceLists(changed=[], left=[])
 
-    @defer.inlineCallbacks
-    def _generate_sync_entry_for_to_device(self, sync_result_builder):
+    async def _generate_sync_entry_for_to_device(self, sync_result_builder):
         """Generates the to-device portion of the sync response. Populates
         `sync_result_builder` with the result.
 
@@ -1165,14 +1154,14 @@ class SyncHandler(object):
             # We only delete messages when a new message comes in, but that's
             # fine so long as we delete them at some point.
 
-            deleted = yield self.store.delete_messages_for_device(
+            deleted = await self.store.delete_messages_for_device(
                 user_id, device_id, since_stream_id
             )
             logger.debug(
                 "Deleted %d to-device messages up to %d", deleted, since_stream_id
             )
 
-            messages, stream_id = yield self.store.get_new_messages_for_device(
+            messages, stream_id = await self.store.get_new_messages_for_device(
                 user_id, device_id, since_stream_id, now_token.to_device_key
             )
 
@@ -1190,8 +1179,7 @@ class SyncHandler(object):
         else:
             sync_result_builder.to_device = []
 
-    @defer.inlineCallbacks
-    def _generate_sync_entry_for_account_data(self, sync_result_builder):
+    async def _generate_sync_entry_for_account_data(self, sync_result_builder):
         """Generates the account data portion of the sync response. Populates
         `sync_result_builder` with the result.
 
@@ -1209,25 +1197,25 @@ class SyncHandler(object):
             (
                 account_data,
                 account_data_by_room,
-            ) = yield self.store.get_updated_account_data_for_user(
+            ) = await self.store.get_updated_account_data_for_user(
                 user_id, since_token.account_data_key
             )
 
-            push_rules_changed = yield self.store.have_push_rules_changed_for_user(
+            push_rules_changed = await self.store.have_push_rules_changed_for_user(
                 user_id, int(since_token.push_rules_key)
             )
 
             if push_rules_changed:
-                account_data["m.push_rules"] = yield self.push_rules_for_user(
+                account_data["m.push_rules"] = await self.push_rules_for_user(
                     sync_config.user
                 )
         else:
             (
                 account_data,
                 account_data_by_room,
-            ) = yield self.store.get_account_data_for_user(sync_config.user.to_string())
+            ) = await self.store.get_account_data_for_user(sync_config.user.to_string())
 
-            account_data["m.push_rules"] = yield self.push_rules_for_user(
+            account_data["m.push_rules"] = await self.push_rules_for_user(
                 sync_config.user
             )
 
@@ -1242,8 +1230,7 @@ class SyncHandler(object):
 
         return account_data_by_room
 
-    @defer.inlineCallbacks
-    def _generate_sync_entry_for_presence(
+    async def _generate_sync_entry_for_presence(
         self, sync_result_builder, newly_joined_rooms, newly_joined_or_invited_users
     ):
         """Generates the presence portion of the sync response. Populates the
@@ -1271,7 +1258,7 @@ class SyncHandler(object):
             presence_key = None
             include_offline = False
 
-        presence, presence_key = yield presence_source.get_new_events(
+        presence, presence_key = await presence_source.get_new_events(
             user=user,
             from_key=presence_key,
             is_guest=sync_config.is_guest,
@@ -1283,12 +1270,12 @@ class SyncHandler(object):
 
         extra_users_ids = set(newly_joined_or_invited_users)
         for room_id in newly_joined_rooms:
-            users = yield self.state.get_current_users_in_room(room_id)
+            users = await self.state.get_current_users_in_room(room_id)
             extra_users_ids.update(users)
         extra_users_ids.discard(user.to_string())
 
         if extra_users_ids:
-            states = yield self.presence_handler.get_states(extra_users_ids)
+            states = await self.presence_handler.get_states(extra_users_ids)
             presence.extend(states)
 
             # Deduplicate the presence entries so that there's at most one per user
@@ -1298,8 +1285,9 @@ class SyncHandler(object):
 
         sync_result_builder.presence = presence
 
-    @defer.inlineCallbacks
-    def _generate_sync_entry_for_rooms(self, sync_result_builder, account_data_by_room):
+    async def _generate_sync_entry_for_rooms(
+        self, sync_result_builder, account_data_by_room
+    ):
         """Generates the rooms portion of the sync response. Populates the
         `sync_result_builder` with the result.
 
@@ -1321,7 +1309,7 @@ class SyncHandler(object):
         if block_all_room_ephemeral:
             ephemeral_by_room = {}
         else:
-            now_token, ephemeral_by_room = yield self.ephemeral_by_room(
+            now_token, ephemeral_by_room = await self.ephemeral_by_room(
                 sync_result_builder,
                 now_token=sync_result_builder.now_token,
                 since_token=sync_result_builder.since_token,
@@ -1333,16 +1321,16 @@ class SyncHandler(object):
         since_token = sync_result_builder.since_token
         if not sync_result_builder.full_state:
             if since_token and not ephemeral_by_room and not account_data_by_room:
-                have_changed = yield self._have_rooms_changed(sync_result_builder)
+                have_changed = await self._have_rooms_changed(sync_result_builder)
                 if not have_changed:
-                    tags_by_room = yield self.store.get_updated_tags(
+                    tags_by_room = await self.store.get_updated_tags(
                         user_id, since_token.account_data_key
                     )
                     if not tags_by_room:
                         logger.debug("no-oping sync")
                         return [], [], [], []
 
-        ignored_account_data = yield self.store.get_global_account_data_by_type_for_user(
+        ignored_account_data = await self.store.get_global_account_data_by_type_for_user(
             "m.ignored_user_list", user_id=user_id
         )
 
@@ -1352,18 +1340,18 @@ class SyncHandler(object):
             ignored_users = frozenset()
 
         if since_token:
-            res = yield self._get_rooms_changed(sync_result_builder, ignored_users)
+            res = await self._get_rooms_changed(sync_result_builder, ignored_users)
             room_entries, invited, newly_joined_rooms, newly_left_rooms = res
 
-            tags_by_room = yield self.store.get_updated_tags(
+            tags_by_room = await self.store.get_updated_tags(
                 user_id, since_token.account_data_key
             )
         else:
-            res = yield self._get_all_rooms(sync_result_builder, ignored_users)
+            res = await self._get_all_rooms(sync_result_builder, ignored_users)
             room_entries, invited, newly_joined_rooms = res
             newly_left_rooms = []
 
-            tags_by_room = yield self.store.get_tags_for_user(user_id)
+            tags_by_room = await self.store.get_tags_for_user(user_id)
 
         def handle_room_entries(room_entry):
             return self._generate_room_entry(
@@ -1376,7 +1364,7 @@ class SyncHandler(object):
                 always_include=sync_result_builder.full_state,
             )
 
-        yield concurrently_execute(handle_room_entries, room_entries, 10)
+        await concurrently_execute(handle_room_entries, room_entries, 10)
 
         sync_result_builder.invited.extend(invited)
 
@@ -1410,8 +1398,7 @@ class SyncHandler(object):
             newly_left_users,
         )
 
-    @defer.inlineCallbacks
-    def _have_rooms_changed(self, sync_result_builder):
+    async def _have_rooms_changed(self, sync_result_builder):
         """Returns whether there may be any new events that should be sent down
         the sync. Returns True if there are.
         """
@@ -1422,7 +1409,7 @@ class SyncHandler(object):
         assert since_token
 
         # Get a list of membership change events that have happened.
-        rooms_changed = yield self.store.get_membership_changes_for_user(
+        rooms_changed = await self.store.get_membership_changes_for_user(
             user_id, since_token.room_key, now_token.room_key
         )
 
@@ -1435,8 +1422,7 @@ class SyncHandler(object):
                 return True
         return False
 
-    @defer.inlineCallbacks
-    def _get_rooms_changed(self, sync_result_builder, ignored_users):
+    async def _get_rooms_changed(self, sync_result_builder, ignored_users):
         """Gets the changes that have happened since the last sync.
 
         Args:
@@ -1461,7 +1447,7 @@ class SyncHandler(object):
         assert since_token
 
         # Get a list of membership change events that have happened.
-        rooms_changed = yield self.store.get_membership_changes_for_user(
+        rooms_changed = await self.store.get_membership_changes_for_user(
             user_id, since_token.room_key, now_token.room_key
         )
 
@@ -1499,11 +1485,11 @@ class SyncHandler(object):
                 continue
 
             if room_id in sync_result_builder.joined_room_ids or has_join:
-                old_state_ids = yield self.get_state_at(room_id, since_token)
+                old_state_ids = await self.get_state_at(room_id, since_token)
                 old_mem_ev_id = old_state_ids.get((EventTypes.Member, user_id), None)
                 old_mem_ev = None
                 if old_mem_ev_id:
-                    old_mem_ev = yield self.store.get_event(
+                    old_mem_ev = await self.store.get_event(
                         old_mem_ev_id, allow_none=True
                     )
 
@@ -1536,13 +1522,13 @@ class SyncHandler(object):
                     newly_left_rooms.append(room_id)
                 else:
                     if not old_state_ids:
-                        old_state_ids = yield self.get_state_at(room_id, since_token)
+                        old_state_ids = await self.get_state_at(room_id, since_token)
                         old_mem_ev_id = old_state_ids.get(
                             (EventTypes.Member, user_id), None
                         )
                         old_mem_ev = None
                         if old_mem_ev_id:
-                            old_mem_ev = yield self.store.get_event(
+                            old_mem_ev = await self.store.get_event(
                                 old_mem_ev_id, allow_none=True
                             )
                     if old_mem_ev and old_mem_ev.membership == Membership.JOIN:
@@ -1566,7 +1552,7 @@ class SyncHandler(object):
 
             if leave_events:
                 leave_event = leave_events[-1]
-                leave_stream_token = yield self.store.get_stream_token_for_event(
+                leave_stream_token = await self.store.get_stream_token_for_event(
                     leave_event.event_id
                 )
                 leave_token = since_token.copy_and_replace(
@@ -1603,7 +1589,7 @@ class SyncHandler(object):
         timeline_limit = sync_config.filter_collection.timeline_limit()
 
         # Get all events for rooms we're currently joined to.
-        room_to_events = yield self.store.get_room_events_stream_for_rooms(
+        room_to_events = await self.store.get_room_events_stream_for_rooms(
             room_ids=sync_result_builder.joined_room_ids,
             from_key=since_token.room_key,
             to_key=now_token.room_key,
@@ -1652,8 +1638,7 @@ class SyncHandler(object):
 
         return room_entries, invited, newly_joined_rooms, newly_left_rooms
 
-    @defer.inlineCallbacks
-    def _get_all_rooms(self, sync_result_builder, ignored_users):
+    async def _get_all_rooms(self, sync_result_builder, ignored_users):
         """Returns entries for all rooms for the user.
 
         Args:
@@ -1677,7 +1662,7 @@ class SyncHandler(object):
             Membership.BAN,
         )
 
-        room_list = yield self.store.get_rooms_for_user_where_membership_is(
+        room_list = await self.store.get_rooms_for_user_where_membership_is(
             user_id=user_id, membership_list=membership_list
         )
 
@@ -1700,7 +1685,7 @@ class SyncHandler(object):
             elif event.membership == Membership.INVITE:
                 if event.sender in ignored_users:
                     continue
-                invite = yield self.store.get_event(event.event_id)
+                invite = await self.store.get_event(event.event_id)
                 invited.append(InvitedSyncResult(room_id=event.room_id, invite=invite))
             elif event.membership in (Membership.LEAVE, Membership.BAN):
                 # Always send down rooms we were banned or kicked from.
@@ -1726,8 +1711,7 @@ class SyncHandler(object):
 
         return room_entries, invited, []
 
-    @defer.inlineCallbacks
-    def _generate_room_entry(
+    async def _generate_room_entry(
         self,
         sync_result_builder,
         ignored_users,
@@ -1769,7 +1753,7 @@ class SyncHandler(object):
         since_token = room_builder.since_token
         upto_token = room_builder.upto_token
 
-        batch = yield self._load_filtered_recents(
+        batch = await self._load_filtered_recents(
             room_id,
             sync_config,
             now_token=upto_token,
@@ -1796,7 +1780,7 @@ class SyncHandler(object):
         # tag was added by synapse e.g. for server notice rooms.
         if full_state:
             user_id = sync_result_builder.sync_config.user.to_string()
-            tags = yield self.store.get_tags_for_room(user_id, room_id)
+            tags = await self.store.get_tags_for_room(user_id, room_id)
 
             # If there aren't any tags, don't send the empty tags list down
             # sync
@@ -1821,7 +1805,7 @@ class SyncHandler(object):
         ):
             return
 
-        state = yield self.compute_state_delta(
+        state = await self.compute_state_delta(
             room_id, batch, sync_config, since_token, now_token, full_state=full_state
         )
 
@@ -1844,7 +1828,7 @@ class SyncHandler(object):
             )
             or since_token is None
         ):
-            summary = yield self.compute_summary(
+            summary = await self.compute_summary(
                 room_id, sync_config, batch, state, now_token
             )
 
@@ -1861,7 +1845,7 @@ class SyncHandler(object):
             )
 
             if room_sync or always_include:
-                notifs = yield self.unread_notifs_for_room_id(room_id, sync_config)
+                notifs = await self.unread_notifs_for_room_id(room_id, sync_config)
 
                 if notifs is not None:
                     unread_notifications["notification_count"] = notifs["notify_count"]
@@ -1887,8 +1871,7 @@ class SyncHandler(object):
         else:
             raise Exception("Unrecognized rtype: %r" % (room_builder.rtype,))
 
-    @defer.inlineCallbacks
-    def get_rooms_for_user_at(self, user_id, stream_ordering):
+    async def get_rooms_for_user_at(self, user_id, stream_ordering):
         """Get set of joined rooms for a user at the given stream ordering.
 
         The stream ordering *must* be recent, otherwise this may throw an
@@ -1903,7 +1886,7 @@ class SyncHandler(object):
             frozenset[str]: Set of room_ids the user is in at the given
             stream_ordering.
         """
-        joined_rooms = yield self.store.get_rooms_for_user_with_stream_ordering(user_id)
+        joined_rooms = await self.store.get_rooms_for_user_with_stream_ordering(user_id)
 
         joined_room_ids = set()
 
@@ -1921,10 +1904,10 @@ class SyncHandler(object):
 
             logger.info("User joined room after current token: %s", room_id)
 
-            extrems = yield self.store.get_forward_extremeties_for_room(
+            extrems = await self.store.get_forward_extremeties_for_room(
                 room_id, stream_ordering
             )
-            users_in_room = yield self.state.get_current_users_in_room(room_id, extrems)
+            users_in_room = await self.state.get_current_users_in_room(room_id, extrems)
             if user_id in users_in_room:
                 joined_room_ids.add(room_id)
 
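Note: the sync.py hunks above all apply one mechanical transformation: drop the @defer.inlineCallbacks decorator, declare the method `async def`, and replace each `yield` on a Deferred with `await`. A minimal sketch of the pattern, with illustrative names not taken from this diff:

    from twisted.internet import defer

    # Before: a Twisted inlineCallbacks generator
    @defer.inlineCallbacks
    def get_thing(store, thing_id):
        thing = yield store.fetch_thing(thing_id)  # yields a Deferred
        return thing

    # After: a native coroutine; `await` accepts a Deferred directly,
    # since Deferred implements __await__ (Twisted >= 16.4)
    async def get_thing_async(store, thing_id):
        thing = await store.fetch_thing(thing_id)
        return thing
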
diff --git a/synapse/handlers/typing.py b/synapse/handlers/typing.py
index 856337b7e2..6f78454322 100644
--- a/synapse/handlers/typing.py
+++ b/synapse/handlers/typing.py
@@ -313,7 +313,7 @@ class TypingNotificationEventSource(object):
 
                 events.append(self._make_event_for(room_id))
 
-            return events, handler._latest_room_serial
+            return defer.succeed((events, handler._latest_room_serial))
 
     def get_current_key(self):
         return self.get_typing_handler()._latest_room_serial
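Note: the defer.succeed wrapper above is needed because get_new_events here is a plain synchronous function, while its callers in sync.py now `await` the result: a bare tuple is not awaitable, but an already-fired Deferred is. A small sketch of the distinction (names illustrative):

    from twisted.internet import defer

    def get_new_events_plain():
        return ([], 0)                  # `await`ing this raises TypeError

    def get_new_events_wrapped():
        return defer.succeed(([], 0))   # safe to `await` from a coroutine
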
diff --git a/synapse/logging/context.py b/synapse/logging/context.py
index 2c1fb9ddac..6747f29e6a 100644
--- a/synapse/logging/context.py
+++ b/synapse/logging/context.py
@@ -23,6 +23,7 @@ them.
 See doc/log_contexts.rst for details on how this works.
 """
 
+import inspect
 import logging
 import threading
 import types
@@ -612,7 +613,8 @@ def run_in_background(f, *args, **kwargs):
 
 
 def make_deferred_yieldable(deferred):
-    """Given a deferred, make it follow the Synapse logcontext rules:
+    """Given a deferred (or coroutine), make it follow the Synapse logcontext
+    rules:
 
     If the deferred has completed (or is not actually a Deferred), essentially
     does nothing (just returns another completed deferred with the
@@ -624,6 +626,13 @@ def make_deferred_yieldable(deferred):
 
     (This is more-or-less the opposite operation to run_in_background.)
     """
+    if inspect.isawaitable(deferred):
+        # If we're given a coroutine we convert it to a deferred so that we
+    # run it and find out if it immediately finishes; if it does then we
+        # don't need to fiddle with log contexts at all and can return
+        # immediately.
+        deferred = defer.ensureDeferred(deferred)
+
     if not isinstance(deferred, defer.Deferred):
         return deferred
 
diff --git a/synapse/module_api/__init__.py b/synapse/module_api/__init__.py
index 735b882363..305b9b0178 100644
--- a/synapse/module_api/__init__.py
+++ b/synapse/module_api/__init__.py
@@ -175,4 +175,4 @@ class ModuleApi(object):
         Returns:
             Deferred[object]: result of func
         """
-        return self._store.runInteraction(desc, func, *args, **kwargs)
+        return self._store.db.runInteraction(desc, func, *args, **kwargs)
diff --git a/synapse/notifier.py b/synapse/notifier.py
index af161a81d7..5f5f765bea 100644
--- a/synapse/notifier.py
+++ b/synapse/notifier.py
@@ -304,8 +304,7 @@ class Notifier(object):
         without waking up any of the normal user event streams"""
         self.notify_replication()
 
-    @defer.inlineCallbacks
-    def wait_for_events(
+    async def wait_for_events(
         self, user_id, timeout, callback, room_ids=None, from_token=StreamToken.START
     ):
         """Wait until the callback returns a non empty response or the
@@ -313,9 +312,9 @@ class Notifier(object):
         """
         user_stream = self.user_to_user_stream.get(user_id)
         if user_stream is None:
-            current_token = yield self.event_sources.get_current_token()
+            current_token = await self.event_sources.get_current_token()
             if room_ids is None:
-                room_ids = yield self.store.get_rooms_for_user(user_id)
+                room_ids = await self.store.get_rooms_for_user(user_id)
             user_stream = _NotifierUserStream(
                 user_id=user_id,
                 rooms=room_ids,
@@ -344,11 +343,11 @@ class Notifier(object):
                         self.hs.get_reactor(),
                     )
                     with PreserveLoggingContext():
-                        yield listener.deferred
+                        await listener.deferred
 
                     current_token = user_stream.current_token
 
-                    result = yield callback(prev_token, current_token)
+                    result = await callback(prev_token, current_token)
                     if result:
                         break
 
@@ -364,12 +363,11 @@ class Notifier(object):
             # This happens if there was no timeout or if the timeout had
             # already expired.
             current_token = user_stream.current_token
-            result = yield callback(prev_token, current_token)
+            result = await callback(prev_token, current_token)
 
         return result
 
-    @defer.inlineCallbacks
-    def get_events_for(
+    async def get_events_for(
         self,
         user,
         pagination_config,
@@ -391,15 +389,14 @@ class Notifier(object):
         """
         from_token = pagination_config.from_token
         if not from_token:
-            from_token = yield self.event_sources.get_current_token()
+            from_token = await self.event_sources.get_current_token()
 
         limit = pagination_config.limit
 
-        room_ids, is_joined = yield self._get_room_ids(user, explicit_room_id)
+        room_ids, is_joined = await self._get_room_ids(user, explicit_room_id)
         is_peeking = not is_joined
 
-        @defer.inlineCallbacks
-        def check_for_updates(before_token, after_token):
+        async def check_for_updates(before_token, after_token):
             if not after_token.is_after(before_token):
                 return EventStreamResult([], (from_token, from_token))
 
@@ -415,7 +412,7 @@ class Notifier(object):
                 if only_keys and name not in only_keys:
                     continue
 
-                new_events, new_key = yield source.get_new_events(
+                new_events, new_key = await source.get_new_events(
                     user=user,
                     from_key=getattr(from_token, keyname),
                     limit=limit,
@@ -425,7 +422,7 @@ class Notifier(object):
                 )
 
                 if name == "room":
-                    new_events = yield filter_events_for_client(
+                    new_events = await filter_events_for_client(
                         self.storage,
                         user.to_string(),
                         new_events,
@@ -461,7 +458,7 @@ class Notifier(object):
                 user_id_for_stream,
             )
 
-        result = yield self.wait_for_events(
+        result = await self.wait_for_events(
             user_id_for_stream,
             timeout,
             check_for_updates,
diff --git a/synapse/replication/slave/storage/_base.py b/synapse/replication/slave/storage/_base.py
index 6ece1d6745..b91a528245 100644
--- a/synapse/replication/slave/storage/_base.py
+++ b/synapse/replication/slave/storage/_base.py
@@ -20,6 +20,7 @@ import six
 
 from synapse.storage._base import SQLBaseStore
 from synapse.storage.data_stores.main.cache import CURRENT_STATE_CACHE_NAME
+from synapse.storage.database import Database
 from synapse.storage.engines import PostgresEngine
 
 from ._slaved_id_tracker import SlavedIdTracker
@@ -35,8 +36,8 @@ def __func__(inp):
 
 
 class BaseSlavedStore(SQLBaseStore):
-    def __init__(self, db_conn, hs):
-        super(BaseSlavedStore, self).__init__(db_conn, hs)
+    def __init__(self, database: Database, db_conn, hs):
+        super(BaseSlavedStore, self).__init__(database, db_conn, hs)
         if isinstance(self.database_engine, PostgresEngine):
             self._cache_id_gen = SlavedIdTracker(
                 db_conn, "cache_invalidation_stream", "stream_id"
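Note: the slave-storage diffs that follow all repeat this constructor change: each store now takes the Database it runs against as an explicit first argument and threads it through super().__init__. A sketch of a call site after the change (the Database(hs) constructor signature is an assumption here, and db_conn/hs come from homeserver setup):

    from synapse.replication.slave.storage.events import SlavedEventStore
    from synapse.storage.database import Database

    # before: store = SlavedEventStore(db_conn, hs)
    database = Database(hs)  # assumed signature
    store = SlavedEventStore(database, db_conn, hs)
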
diff --git a/synapse/replication/slave/storage/account_data.py b/synapse/replication/slave/storage/account_data.py
index bc2f6a12ae..ebe94909cb 100644
--- a/synapse/replication/slave/storage/account_data.py
+++ b/synapse/replication/slave/storage/account_data.py
@@ -18,15 +18,16 @@ from synapse.replication.slave.storage._base import BaseSlavedStore
 from synapse.replication.slave.storage._slaved_id_tracker import SlavedIdTracker
 from synapse.storage.data_stores.main.account_data import AccountDataWorkerStore
 from synapse.storage.data_stores.main.tags import TagsWorkerStore
+from synapse.storage.database import Database
 
 
 class SlavedAccountDataStore(TagsWorkerStore, AccountDataWorkerStore, BaseSlavedStore):
-    def __init__(self, db_conn, hs):
+    def __init__(self, database: Database, db_conn, hs):
         self._account_data_id_gen = SlavedIdTracker(
             db_conn, "account_data_max_stream_id", "stream_id"
         )
 
-        super(SlavedAccountDataStore, self).__init__(db_conn, hs)
+        super(SlavedAccountDataStore, self).__init__(database, db_conn, hs)
 
     def get_max_account_data_stream_id(self):
         return self._account_data_id_gen.get_current_token()
diff --git a/synapse/replication/slave/storage/client_ips.py b/synapse/replication/slave/storage/client_ips.py
index b4f58cea19..fbf996e33a 100644
--- a/synapse/replication/slave/storage/client_ips.py
+++ b/synapse/replication/slave/storage/client_ips.py
@@ -14,6 +14,7 @@
 # limitations under the License.
 
 from synapse.storage.data_stores.main.client_ips import LAST_SEEN_GRANULARITY
+from synapse.storage.database import Database
 from synapse.util.caches import CACHE_SIZE_FACTOR
 from synapse.util.caches.descriptors import Cache
 
@@ -21,8 +22,8 @@ from ._base import BaseSlavedStore
 
 
 class SlavedClientIpStore(BaseSlavedStore):
-    def __init__(self, db_conn, hs):
-        super(SlavedClientIpStore, self).__init__(db_conn, hs)
+    def __init__(self, database: Database, db_conn, hs):
+        super(SlavedClientIpStore, self).__init__(database, db_conn, hs)
 
         self.client_ip_last_seen = Cache(
             name="client_ip_last_seen", keylen=4, max_entries=50000 * CACHE_SIZE_FACTOR
diff --git a/synapse/replication/slave/storage/deviceinbox.py b/synapse/replication/slave/storage/deviceinbox.py
index 9fb6c5c6ff..0c237c6e0f 100644
--- a/synapse/replication/slave/storage/deviceinbox.py
+++ b/synapse/replication/slave/storage/deviceinbox.py
@@ -16,13 +16,14 @@
 from synapse.replication.slave.storage._base import BaseSlavedStore
 from synapse.replication.slave.storage._slaved_id_tracker import SlavedIdTracker
 from synapse.storage.data_stores.main.deviceinbox import DeviceInboxWorkerStore
+from synapse.storage.database import Database
 from synapse.util.caches.expiringcache import ExpiringCache
 from synapse.util.caches.stream_change_cache import StreamChangeCache
 
 
 class SlavedDeviceInboxStore(DeviceInboxWorkerStore, BaseSlavedStore):
-    def __init__(self, db_conn, hs):
-        super(SlavedDeviceInboxStore, self).__init__(db_conn, hs)
+    def __init__(self, database: Database, db_conn, hs):
+        super(SlavedDeviceInboxStore, self).__init__(database, db_conn, hs)
         self._device_inbox_id_gen = SlavedIdTracker(
             db_conn, "device_max_stream_id", "stream_id"
         )
diff --git a/synapse/replication/slave/storage/devices.py b/synapse/replication/slave/storage/devices.py
index de50748c30..dc625e0d7a 100644
--- a/synapse/replication/slave/storage/devices.py
+++ b/synapse/replication/slave/storage/devices.py
@@ -18,12 +18,13 @@ from synapse.replication.slave.storage._slaved_id_tracker import SlavedIdTracker
 from synapse.replication.tcp.streams._base import DeviceListsStream, UserSignatureStream
 from synapse.storage.data_stores.main.devices import DeviceWorkerStore
 from synapse.storage.data_stores.main.end_to_end_keys import EndToEndKeyWorkerStore
+from synapse.storage.database import Database
 from synapse.util.caches.stream_change_cache import StreamChangeCache
 
 
 class SlavedDeviceStore(EndToEndKeyWorkerStore, DeviceWorkerStore, BaseSlavedStore):
-    def __init__(self, db_conn, hs):
-        super(SlavedDeviceStore, self).__init__(db_conn, hs)
+    def __init__(self, database: Database, db_conn, hs):
+        super(SlavedDeviceStore, self).__init__(database, db_conn, hs)
 
         self.hs = hs
 
diff --git a/synapse/replication/slave/storage/events.py b/synapse/replication/slave/storage/events.py
index d0a0eaf75b..29f35b9915 100644
--- a/synapse/replication/slave/storage/events.py
+++ b/synapse/replication/slave/storage/events.py
@@ -31,6 +31,7 @@ from synapse.storage.data_stores.main.signatures import SignatureWorkerStore
 from synapse.storage.data_stores.main.state import StateGroupWorkerStore
 from synapse.storage.data_stores.main.stream import StreamWorkerStore
 from synapse.storage.data_stores.main.user_erasure_store import UserErasureWorkerStore
+from synapse.storage.database import Database
 
 from ._base import BaseSlavedStore
 from ._slaved_id_tracker import SlavedIdTracker
@@ -59,13 +60,13 @@ class SlavedEventStore(
     RelationsWorkerStore,
     BaseSlavedStore,
 ):
-    def __init__(self, db_conn, hs):
+    def __init__(self, database: Database, db_conn, hs):
         self._stream_id_gen = SlavedIdTracker(db_conn, "events", "stream_ordering")
         self._backfill_id_gen = SlavedIdTracker(
             db_conn, "events", "stream_ordering", step=-1
         )
 
-        super(SlavedEventStore, self).__init__(db_conn, hs)
+        super(SlavedEventStore, self).__init__(database, db_conn, hs)
 
     # Cached functions can't be accessed through a class instance so we need
     # to reach inside the __dict__ to extract them.
diff --git a/synapse/replication/slave/storage/filtering.py b/synapse/replication/slave/storage/filtering.py
index 5c84ebd125..bcb0688954 100644
--- a/synapse/replication/slave/storage/filtering.py
+++ b/synapse/replication/slave/storage/filtering.py
@@ -14,13 +14,14 @@
 # limitations under the License.
 
 from synapse.storage.data_stores.main.filtering import FilteringStore
+from synapse.storage.database import Database
 
 from ._base import BaseSlavedStore
 
 
 class SlavedFilteringStore(BaseSlavedStore):
-    def __init__(self, db_conn, hs):
-        super(SlavedFilteringStore, self).__init__(db_conn, hs)
+    def __init__(self, database: Database, db_conn, hs):
+        super(SlavedFilteringStore, self).__init__(database, db_conn, hs)
 
     # Filters are immutable so this cache doesn't need to be expired
     get_user_filter = FilteringStore.__dict__["get_user_filter"]
diff --git a/synapse/replication/slave/storage/groups.py b/synapse/replication/slave/storage/groups.py
index 28a46edd28..69a4ae42f9 100644
--- a/synapse/replication/slave/storage/groups.py
+++ b/synapse/replication/slave/storage/groups.py
@@ -14,6 +14,7 @@
 # limitations under the License.
 
 from synapse.storage import DataStore
+from synapse.storage.database import Database
 from synapse.util.caches.stream_change_cache import StreamChangeCache
 
 from ._base import BaseSlavedStore, __func__
@@ -21,8 +22,8 @@ from ._slaved_id_tracker import SlavedIdTracker
 
 
 class SlavedGroupServerStore(BaseSlavedStore):
-    def __init__(self, db_conn, hs):
-        super(SlavedGroupServerStore, self).__init__(db_conn, hs)
+    def __init__(self, database: Database, db_conn, hs):
+        super(SlavedGroupServerStore, self).__init__(database, db_conn, hs)
 
         self.hs = hs
 
diff --git a/synapse/replication/slave/storage/presence.py b/synapse/replication/slave/storage/presence.py
index 747ced0c84..f552e7c972 100644
--- a/synapse/replication/slave/storage/presence.py
+++ b/synapse/replication/slave/storage/presence.py
@@ -15,6 +15,7 @@
 
 from synapse.storage import DataStore
 from synapse.storage.data_stores.main.presence import PresenceStore
+from synapse.storage.database import Database
 from synapse.util.caches.stream_change_cache import StreamChangeCache
 
 from ._base import BaseSlavedStore, __func__
@@ -22,8 +23,8 @@ from ._slaved_id_tracker import SlavedIdTracker
 
 
 class SlavedPresenceStore(BaseSlavedStore):
-    def __init__(self, db_conn, hs):
-        super(SlavedPresenceStore, self).__init__(db_conn, hs)
+    def __init__(self, database: Database, db_conn, hs):
+        super(SlavedPresenceStore, self).__init__(database, db_conn, hs)
         self._presence_id_gen = SlavedIdTracker(db_conn, "presence_stream", "stream_id")
 
         self._presence_on_startup = self._get_active_presence(db_conn)
diff --git a/synapse/replication/slave/storage/push_rule.py b/synapse/replication/slave/storage/push_rule.py
index 3655f05e54..eebd5a1fb6 100644
--- a/synapse/replication/slave/storage/push_rule.py
+++ b/synapse/replication/slave/storage/push_rule.py
@@ -15,17 +15,18 @@
 # limitations under the License.
 
 from synapse.storage.data_stores.main.push_rule import PushRulesWorkerStore
+from synapse.storage.database import Database
 
 from ._slaved_id_tracker import SlavedIdTracker
 from .events import SlavedEventStore
 
 
 class SlavedPushRuleStore(SlavedEventStore, PushRulesWorkerStore):
-    def __init__(self, db_conn, hs):
+    def __init__(self, database: Database, db_conn, hs):
         self._push_rules_stream_id_gen = SlavedIdTracker(
             db_conn, "push_rules_stream", "stream_id"
         )
-        super(SlavedPushRuleStore, self).__init__(db_conn, hs)
+        super(SlavedPushRuleStore, self).__init__(database, db_conn, hs)
 
     def get_push_rules_stream_token(self):
         return (
diff --git a/synapse/replication/slave/storage/pushers.py b/synapse/replication/slave/storage/pushers.py
index b4331d0799..f22c2d44a3 100644
--- a/synapse/replication/slave/storage/pushers.py
+++ b/synapse/replication/slave/storage/pushers.py
@@ -15,14 +15,15 @@
 # limitations under the License.
 
 from synapse.storage.data_stores.main.pusher import PusherWorkerStore
+from synapse.storage.database import Database
 
 from ._base import BaseSlavedStore
 from ._slaved_id_tracker import SlavedIdTracker
 
 
 class SlavedPusherStore(PusherWorkerStore, BaseSlavedStore):
-    def __init__(self, db_conn, hs):
-        super(SlavedPusherStore, self).__init__(db_conn, hs)
+    def __init__(self, database: Database, db_conn, hs):
+        super(SlavedPusherStore, self).__init__(database, db_conn, hs)
         self._pushers_id_gen = SlavedIdTracker(
             db_conn, "pushers", "id", extra_tables=[("deleted_pushers", "stream_id")]
         )
diff --git a/synapse/replication/slave/storage/receipts.py b/synapse/replication/slave/storage/receipts.py
index 43d823c601..d40dc6e1f5 100644
--- a/synapse/replication/slave/storage/receipts.py
+++ b/synapse/replication/slave/storage/receipts.py
@@ -15,6 +15,7 @@
 # limitations under the License.
 
 from synapse.storage.data_stores.main.receipts import ReceiptsWorkerStore
+from synapse.storage.database import Database
 
 from ._base import BaseSlavedStore
 from ._slaved_id_tracker import SlavedIdTracker
@@ -29,14 +30,14 @@ from ._slaved_id_tracker import SlavedIdTracker
 
 
 class SlavedReceiptsStore(ReceiptsWorkerStore, BaseSlavedStore):
-    def __init__(self, db_conn, hs):
+    def __init__(self, database: Database, db_conn, hs):
         # We instantiate this first as the ReceiptsWorkerStore constructor
         # needs to be able to call get_max_receipt_stream_id
         self._receipts_id_gen = SlavedIdTracker(
             db_conn, "receipts_linearized", "stream_id"
         )
 
-        super(SlavedReceiptsStore, self).__init__(db_conn, hs)
+        super(SlavedReceiptsStore, self).__init__(database, db_conn, hs)
 
     def get_max_receipt_stream_id(self):
         return self._receipts_id_gen.get_current_token()
diff --git a/synapse/replication/slave/storage/room.py b/synapse/replication/slave/storage/room.py
index d9ad386b28..3a20f45316 100644
--- a/synapse/replication/slave/storage/room.py
+++ b/synapse/replication/slave/storage/room.py
@@ -14,14 +14,15 @@
 # limitations under the License.
 
 from synapse.storage.data_stores.main.room import RoomWorkerStore
+from synapse.storage.database import Database
 
 from ._base import BaseSlavedStore
 from ._slaved_id_tracker import SlavedIdTracker
 
 
 class RoomStore(RoomWorkerStore, BaseSlavedStore):
-    def __init__(self, db_conn, hs):
-        super(RoomStore, self).__init__(db_conn, hs)
+    def __init__(self, database: Database, db_conn, hs):
+        super(RoomStore, self).__init__(database, db_conn, hs)
         self._public_room_id_gen = SlavedIdTracker(
             db_conn, "public_room_list_stream", "stream_id"
         )
diff --git a/synapse/rest/client/v1/profile.py b/synapse/rest/client/v1/profile.py
index 1eac8a44c5..4f47562c1b 100644
--- a/synapse/rest/client/v1/profile.py
+++ b/synapse/rest/client/v1/profile.py
@@ -103,11 +103,16 @@ class ProfileAvatarURLRestServlet(RestServlet):
 
         content = parse_json_object_from_request(request)
         try:
-            new_name = content["avatar_url"]
+            new_avatar_url = content.get("avatar_url")
         except Exception:
-            return 400, "Unable to parse name"
+            return 400, "Unable to parse avatar_url"
+
+        if new_avatar_url is None:
+            return 400, "Missing required key: avatar_url"
 
-        await self.profile_handler.set_avatar_url(user, requester, new_name, is_admin)
+        await self.profile_handler.set_avatar_url(
+            user, requester, new_avatar_url, is_admin
+        )
 
         return 200, {}
 
diff --git a/synapse/rest/media/v1/preview_url_resource.py b/synapse/rest/media/v1/preview_url_resource.py
index fb0d02aa83..6b978be876 100644
--- a/synapse/rest/media/v1/preview_url_resource.py
+++ b/synapse/rest/media/v1/preview_url_resource.py
@@ -402,7 +402,7 @@ class PreviewUrlResource(DirectServeResource):
 
         logger.info("Running url preview cache expiry")
 
-        if not (yield self.store.has_completed_background_updates()):
+        if not (yield self.store.db.updates.has_completed_background_updates()):
             logger.info("Still running DB updates; skipping expiry")
             return
 
diff --git a/synapse/server.py b/synapse/server.py
index be9af7f986..2db3dab221 100644
--- a/synapse/server.py
+++ b/synapse/server.py
@@ -238,8 +238,7 @@ class HomeServer(object):
     def setup(self):
         logger.info("Setting up.")
         with self.get_db_conn() as conn:
-            datastore = self.DATASTORE_CLASS(conn, self)
-            self.datastores = DataStores(datastore, conn, self)
+            self.datastores = DataStores(self.DATASTORE_CLASS, conn, self)
             conn.commit()
         self.start_time = int(self.get_clock().time())
         logger.info("Finished setting up.")
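Note: setup() now hands DataStores the datastore class rather than a ready-made instance, which lets DataStores build the Database and pass it to the store constructor itself. A plausible sketch of the receiving side, under that assumption (the actual DataStores body is not shown in this diff):

    class DataStores(object):
        def __init__(self, main_store_class, db_conn, hs):
            # assumed: DataStores now owns Database creation
            database = Database(hs)
            self.main = main_store_class(database, db_conn, hs)
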
diff --git a/synapse/storage/__init__.py b/synapse/storage/__init__.py
index 0460fe8cc9..ec89f645d4 100644
--- a/synapse/storage/__init__.py
+++ b/synapse/storage/__init__.py
@@ -17,10 +17,10 @@
 """
 The storage layer is split up into multiple parts to allow Synapse to run
 against different configurations of databases (e.g. single or multiple
-databases). The `data_stores` are classes that talk directly to a single
-database and have associated schemas, background updates, etc. On top of those
-there are (or will be) classes that provide high level interfaces that combine
-calls to multiple `data_stores`.
+databases). The `Database` class represents a single physical database. The
+`data_stores` are classes that talk directly to a `Database` instance and have
+associated schemas, background updates, etc. On top of those there are classes
+that provide high level interfaces that combine calls to multiple `data_stores`.
 
 There are also schemas that get applied to every database, regardless of the
 data stores associated with them (e.g. the schema version tables), which are
@@ -49,15 +49,3 @@ class Storage(object):
         self.persistence = EventsPersistenceStorage(hs, stores)
         self.purge_events = PurgeEventsStorage(hs, stores)
         self.state = StateGroupStorage(hs, stores)
-
-
-def are_all_users_on_domain(txn, database_engine, domain):
-    sql = database_engine.convert_param_style(
-        "SELECT COUNT(*) FROM users WHERE name NOT LIKE ?"
-    )
-    pat = "%:" + domain
-    txn.execute(sql, (pat,))
-    num_not_matching = txn.fetchall()[0][0]
-    if num_not_matching == 0:
-        return True
-    return False
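
The reworked module docstring describes a three-layer split. Roughly, the
layers stack as in this sketch (the class bodies are illustrative outlines,
not the real implementations):

    class Database(object):
        """One physical database: owns the connection pool and the low-level
        query helpers (runInteraction, the simple_* methods, and so on)."""

    class EventsDataStore(object):
        """A data store: talks to exactly one Database and owns the schema
        and background updates for its tables."""
        def __init__(self, database):
            self.db = database

    class EventsPersistenceStorage(object):
        """High-level interface: combines calls across data stores without
        knowing which physical database each one lives on."""
        def __init__(self, stores):
            self.stores = stores
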
diff --git a/synapse/storage/_base.py b/synapse/storage/_base.py
index 0d7c7dff27..b7637b5dc0 100644
--- a/synapse/storage/_base.py
+++ b/synapse/storage/_base.py
@@ -16,1304 +16,34 @@
 # limitations under the License.
 import logging
 import random
-import sys
-import time
-from typing import Iterable, Tuple
 
-from six import PY2, iteritems, iterkeys, itervalues
-from six.moves import builtins, intern, range
+from six import PY2
+from six.moves import builtins
 
 from canonicaljson import json
-from prometheus_client import Histogram
 
-from twisted.internet import defer
-
-from synapse.api.errors import StoreError
-from synapse.logging.context import LoggingContext, make_deferred_yieldable
-from synapse.metrics.background_process_metrics import run_as_background_process
-from synapse.storage.engines import PostgresEngine, Sqlite3Engine
+from synapse.storage.database import LoggingTransaction  # noqa: F401
+from synapse.storage.database import make_in_list_sql_clause  # noqa: F401
+from synapse.storage.database import Database
 from synapse.types import get_domain_from_id
-from synapse.util.stringutils import exception_to_unicode
-
-# import a function which will return a monotonic time, in seconds
-try:
-    # on python 3, use time.monotonic, since time.clock can go backwards
-    from time import monotonic as monotonic_time
-except ImportError:
-    # ... but python 2 doesn't have it
-    from time import clock as monotonic_time
 
 logger = logging.getLogger(__name__)
 
-try:
-    MAX_TXN_ID = sys.maxint - 1
-except AttributeError:
-    # python 3 does not have a maximum int value
-    MAX_TXN_ID = 2 ** 63 - 1
-
-sql_logger = logging.getLogger("synapse.storage.SQL")
-transaction_logger = logging.getLogger("synapse.storage.txn")
-perf_logger = logging.getLogger("synapse.storage.TIME")
-
-sql_scheduling_timer = Histogram("synapse_storage_schedule_time", "sec")
-
-sql_query_timer = Histogram("synapse_storage_query_time", "sec", ["verb"])
-sql_txn_timer = Histogram("synapse_storage_transaction_time", "sec", ["desc"])
-
-
-# Unique indexes which have been added in background updates. Maps from table name
-# to the name of the background update which added the unique index to that table.
-#
-# This is used by the upsert logic to figure out which tables are safe to do a proper
-# UPSERT on: until the relevant background update has completed, we
-# have to emulate an upsert by locking the table.
-#
-UNIQUE_INDEX_BACKGROUND_UPDATES = {
-    "user_ips": "user_ips_device_unique_index",
-    "device_lists_remote_extremeties": "device_lists_remote_extremeties_unique_idx",
-    "device_lists_remote_cache": "device_lists_remote_cache_unique_idx",
-    "event_search": "event_search_event_id_idx",
-}
-
 
-class LoggingTransaction(object):
-    """An object that almost-transparently proxies for the 'txn' object
-    passed to the constructor. Adds logging and metrics to the .execute()
-    method.
+class SQLBaseStore(object):
+    """Base class for data stores that holds helper functions.
 
-    Args:
-        txn: The database transaction object to wrap.
-        name (str): The name of this transaction for logging.
-        database_engine (Sqlite3Engine|PostgresEngine)
-        after_callbacks(list|None): A list that callbacks will be appended to
-            that have been added by `call_after` which should be run on
-            successful completion of the transaction. None indicates that no
-            callbacks should be allowed to be scheduled to run.
-        exception_callbacks(list|None): A list that callbacks will be appended
-            to that have been added by `call_on_exception` which should be run
-            if transaction ends with an error. None indicates that no callbacks
-            should be allowed to be scheduled to run.
+    Note that multiple instances of this class will exist as there will be one
+    per data store (and not one per physical database).
     """
 
-    __slots__ = [
-        "txn",
-        "name",
-        "database_engine",
-        "after_callbacks",
-        "exception_callbacks",
-    ]
-
-    def __init__(
-        self, txn, name, database_engine, after_callbacks=None, exception_callbacks=None
-    ):
-        object.__setattr__(self, "txn", txn)
-        object.__setattr__(self, "name", name)
-        object.__setattr__(self, "database_engine", database_engine)
-        object.__setattr__(self, "after_callbacks", after_callbacks)
-        object.__setattr__(self, "exception_callbacks", exception_callbacks)
-
-    def call_after(self, callback, *args, **kwargs):
-        """Call the given callback on the main twisted thread after the
-        transaction has finished. Used to invalidate the caches on the
-        correct thread.
-        """
-        self.after_callbacks.append((callback, args, kwargs))
-
-    def call_on_exception(self, callback, *args, **kwargs):
-        self.exception_callbacks.append((callback, args, kwargs))
-
-    def __getattr__(self, name):
-        return getattr(self.txn, name)
-
-    def __setattr__(self, name, value):
-        setattr(self.txn, name, value)
-
-    def __iter__(self):
-        return self.txn.__iter__()
-
-    def execute_batch(self, sql, args):
-        if isinstance(self.database_engine, PostgresEngine):
-            from psycopg2.extras import execute_batch
-
-            self._do_execute(lambda *x: execute_batch(self.txn, *x), sql, args)
-        else:
-            for val in args:
-                self.execute(sql, val)
-
-    def execute(self, sql, *args):
-        self._do_execute(self.txn.execute, sql, *args)
-
-    def executemany(self, sql, *args):
-        self._do_execute(self.txn.executemany, sql, *args)
-
-    def _make_sql_one_line(self, sql):
-        "Strip newlines out of SQL so that the loggers in the DB are on one line"
-        return " ".join(l.strip() for l in sql.splitlines() if l.strip())
-
-    def _do_execute(self, func, sql, *args):
-        sql = self._make_sql_one_line(sql)
-
-        # TODO(paul): Maybe use 'info' and 'debug' for values?
-        sql_logger.debug("[SQL] {%s} %s", self.name, sql)
-
-        sql = self.database_engine.convert_param_style(sql)
-        if args:
-            try:
-                sql_logger.debug("[SQL values] {%s} %r", self.name, args[0])
-            except Exception:
-                # Don't let logging failures stop SQL from working
-                pass
-
-        start = time.time()
-
-        try:
-            return func(sql, *args)
-        except Exception as e:
-            logger.debug("[SQL FAIL] {%s} %s", self.name, e)
-            raise
-        finally:
-            secs = time.time() - start
-            sql_logger.debug("[SQL time] {%s} %f sec", self.name, secs)
-            sql_query_timer.labels(sql.split()[0]).observe(secs)
-
-
-class PerformanceCounters(object):
-    def __init__(self):
-        self.current_counters = {}
-        self.previous_counters = {}
-
-    def update(self, key, duration_secs):
-        count, cum_time = self.current_counters.get(key, (0, 0))
-        count += 1
-        cum_time += duration_secs
-        self.current_counters[key] = (count, cum_time)
-
-    def interval(self, interval_duration_secs, limit=3):
-        counters = []
-        for name, (count, cum_time) in iteritems(self.current_counters):
-            prev_count, prev_time = self.previous_counters.get(name, (0, 0))
-            counters.append(
-                (
-                    (cum_time - prev_time) / interval_duration_secs,
-                    count - prev_count,
-                    name,
-                )
-            )
-
-        self.previous_counters = dict(self.current_counters)
-
-        counters.sort(reverse=True)
-
-        top_n_counters = ", ".join(
-            "%s(%d): %.3f%%" % (name, count, 100 * ratio)
-            for ratio, count, name in counters[:limit]
-        )
-
-        return top_n_counters
-
-
-class SQLBaseStore(object):
-    _TXN_ID = 0
-
-    def __init__(self, db_conn, hs):
+    def __init__(self, database: Database, db_conn, hs):
         self.hs = hs
         self._clock = hs.get_clock()
-        self._db_pool = hs.get_db_pool()
-
-        self._previous_txn_total_time = 0
-        self._current_txn_total_time = 0
-        self._previous_loop_ts = 0
-
-        # TODO(paul): These can eventually be removed once the metrics code
-        #   is running in mainline, and we have some nice monitoring frontends
-        #   to watch it
-        self._txn_perf_counters = PerformanceCounters()
-
         self.database_engine = hs.database_engine
-
-        # A set of tables that are not safe to use native upserts in.
-        self._unsafe_to_upsert_tables = set(UNIQUE_INDEX_BACKGROUND_UPDATES.keys())
-
-        # We add the user_directory_search table to the blacklist on SQLite
-        # because the existing search table does not have an index, making it
-        # unsafe to use native upserts.
-        if isinstance(self.database_engine, Sqlite3Engine):
-            self._unsafe_to_upsert_tables.add("user_directory_search")
-
-        if self.database_engine.can_native_upsert:
-            # Check ASAP (and then later, every 1s) to see if we have finished
-            # background updates of tables that aren't safe to update.
-            self._clock.call_later(
-                0.0,
-                run_as_background_process,
-                "upsert_safety_check",
-                self._check_safe_to_upsert,
-            )
-
+        self.db = database
         self.rand = random.SystemRandom()
 
-    @defer.inlineCallbacks
-    def _check_safe_to_upsert(self):
-        """
-        Is it safe to use native UPSERT?
-
-        If there are background updates, we will need to wait, as they may be
-        the addition of indexes that set the UNIQUE constraint that we require.
-
-        If the background updates have not completed, wait 15 sec and check again.
-        """
-        updates = yield self.simple_select_list(
-            "background_updates",
-            keyvalues=None,
-            retcols=["update_name"],
-            desc="check_background_updates",
-        )
-        updates = [x["update_name"] for x in updates]
-
-        for table, update_name in UNIQUE_INDEX_BACKGROUND_UPDATES.items():
-            if update_name not in updates:
-                logger.debug("Now safe to upsert in %s", table)
-                self._unsafe_to_upsert_tables.discard(table)
-
-        # If there's any updates still running, reschedule to run.
-        if updates:
-            self._clock.call_later(
-                15.0,
-                run_as_background_process,
-                "upsert_safety_check",
-                self._check_safe_to_upsert,
-            )
-
-    def start_profiling(self):
-        self._previous_loop_ts = monotonic_time()
-
-        def loop():
-            curr = self._current_txn_total_time
-            prev = self._previous_txn_total_time
-            self._previous_txn_total_time = curr
-
-            time_now = monotonic_time()
-            time_then = self._previous_loop_ts
-            self._previous_loop_ts = time_now
-
-            duration = time_now - time_then
-            ratio = (curr - prev) / duration
-
-            top_three_counters = self._txn_perf_counters.interval(duration, limit=3)
-
-            perf_logger.info(
-                "Total database time: %.3f%% {%s}", ratio * 100, top_three_counters
-            )
-
-        self._clock.looping_call(loop, 10000)
-
-    def new_transaction(
-        self, conn, desc, after_callbacks, exception_callbacks, func, *args, **kwargs
-    ):
-        start = monotonic_time()
-        txn_id = self._TXN_ID
-
-        # We don't really need these to be unique, so let's stop it from
-        # growing really large.
-        self._TXN_ID = (self._TXN_ID + 1) % (MAX_TXN_ID)
-
-        name = "%s-%x" % (desc, txn_id)
-
-        transaction_logger.debug("[TXN START] {%s}", name)
-
-        try:
-            i = 0
-            N = 5
-            while True:
-                cursor = LoggingTransaction(
-                    conn.cursor(),
-                    name,
-                    self.database_engine,
-                    after_callbacks,
-                    exception_callbacks,
-                )
-                try:
-                    r = func(cursor, *args, **kwargs)
-                    conn.commit()
-                    return r
-                except self.database_engine.module.OperationalError as e:
-                    # This can happen if the database disappears mid
-                    # transaction.
-                    logger.warning(
-                        "[TXN OPERROR] {%s} %s %d/%d",
-                        name,
-                        exception_to_unicode(e),
-                        i,
-                        N,
-                    )
-                    if i < N:
-                        i += 1
-                        try:
-                            conn.rollback()
-                        except self.database_engine.module.Error as e1:
-                            logger.warning(
-                                "[TXN EROLL] {%s} %s", name, exception_to_unicode(e1)
-                            )
-                        continue
-                    raise
-                except self.database_engine.module.DatabaseError as e:
-                    if self.database_engine.is_deadlock(e):
-                        logger.warning("[TXN DEADLOCK] {%s} %d/%d", name, i, N)
-                        if i < N:
-                            i += 1
-                            try:
-                                conn.rollback()
-                            except self.database_engine.module.Error as e1:
-                                logger.warning(
-                                    "[TXN EROLL] {%s} %s",
-                                    name,
-                                    exception_to_unicode(e1),
-                                )
-                            continue
-                    raise
-                finally:
-                    # we're either about to retry with a new cursor, or we're about to
-                    # release the connection. Once we release the connection, it could
-                    # get used for another query, which might do a conn.rollback().
-                    #
-                    # In the latter case, even though that probably wouldn't affect the
-                    # results of this transaction, python's sqlite will reset all
-                    # statements on the connection [1], which will make our cursor
-                    # invalid [2].
-                    #
-                    # In any case, continuing to read rows after commit()ing seems
-                    # dubious from the PoV of ACID transactional semantics
-                    # (sqlite explicitly says that once you commit, you may see rows
-                    # from subsequent updates.)
-                    #
-                    # In psycopg2, cursors are essentially a client-side fabrication -
-                    # all the data is transferred to the client side when the statement
-                    # finishes executing - so in theory we could go on streaming results
-                    # from the cursor, but attempting to do so would make us
-                    # incompatible with sqlite, so let's make sure we're not doing that
-                    # by closing the cursor.
-                    #
-                    # (*named* cursors in psycopg2 are different and are proper server-
-                    # side things, but (a) we don't use them and (b) they are implicitly
-                    # closed by ending the transaction anyway.)
-                    #
-                    # In short, if we haven't finished with the cursor yet, that's a
-                    # problem waiting to bite us.
-                    #
-                    # TL;DR: we're done with the cursor, so we can close it.
-                    #
-                    # [1]: https://github.com/python/cpython/blob/v3.8.0/Modules/_sqlite/connection.c#L465
-                    # [2]: https://github.com/python/cpython/blob/v3.8.0/Modules/_sqlite/cursor.c#L236
-                    cursor.close()
-        except Exception as e:
-            logger.debug("[TXN FAIL] {%s} %s", name, e)
-            raise
-        finally:
-            end = monotonic_time()
-            duration = end - start
-
-            LoggingContext.current_context().add_database_transaction(duration)
-
-            transaction_logger.debug("[TXN END] {%s} %f sec", name, duration)
-
-            self._current_txn_total_time += duration
-            self._txn_perf_counters.update(desc, duration)
-            sql_txn_timer.labels(desc).observe(duration)
-
-    @defer.inlineCallbacks
-    def runInteraction(self, desc, func, *args, **kwargs):
-        """Starts a transaction on the database and runs a given function
-
-        Arguments:
-            desc (str): description of the transaction, for logging and metrics
-            func (func): callback function, which will be called with a
-                database transaction (twisted.enterprise.adbapi.Transaction) as
-                its first argument, followed by `args` and `kwargs`.
-
-            args (list): positional args to pass to `func`
-            kwargs (dict): named args to pass to `func`
-
-        Returns:
-            Deferred: The result of func
-        """
-        after_callbacks = []
-        exception_callbacks = []
-
-        if LoggingContext.current_context() == LoggingContext.sentinel:
-            logger.warning("Starting db txn '%s' from sentinel context", desc)
-
-        try:
-            result = yield self.runWithConnection(
-                self.new_transaction,
-                desc,
-                after_callbacks,
-                exception_callbacks,
-                func,
-                *args,
-                **kwargs
-            )
-
-            for after_callback, after_args, after_kwargs in after_callbacks:
-                after_callback(*after_args, **after_kwargs)
-        except:  # noqa: E722, as we reraise the exception this is fine.
-            for after_callback, after_args, after_kwargs in exception_callbacks:
-                after_callback(*after_args, **after_kwargs)
-            raise
-
-        return result
-
-    @defer.inlineCallbacks
-    def runWithConnection(self, func, *args, **kwargs):
-        """Wraps the .runWithConnection() method on the underlying db_pool.
-
-        Arguments:
-            func (func): callback function, which will be called with a
-                database connection (twisted.enterprise.adbapi.Connection) as
-                its first argument, followed by `args` and `kwargs`.
-            args (list): positional args to pass to `func`
-            kwargs (dict): named args to pass to `func`
-
-        Returns:
-            Deferred: The result of func
-        """
-        parent_context = LoggingContext.current_context()
-        if parent_context == LoggingContext.sentinel:
-            logger.warning(
-                "Starting db connection from sentinel context: metrics will be lost"
-            )
-            parent_context = None
-
-        start_time = monotonic_time()
-
-        def inner_func(conn, *args, **kwargs):
-            with LoggingContext("runWithConnection", parent_context) as context:
-                sched_duration_sec = monotonic_time() - start_time
-                sql_scheduling_timer.observe(sched_duration_sec)
-                context.add_database_scheduled(sched_duration_sec)
-
-                if self.database_engine.is_connection_closed(conn):
-                    logger.debug("Reconnecting closed database connection")
-                    conn.reconnect()
-
-                return func(conn, *args, **kwargs)
-
-        result = yield make_deferred_yieldable(
-            self._db_pool.runWithConnection(inner_func, *args, **kwargs)
-        )
-
-        return result
-
-    @staticmethod
-    def cursor_to_dict(cursor):
-        """Converts a SQL cursor into an list of dicts.
-
-        Args:
-            cursor : The DBAPI cursor which has executed a query.
-        Returns:
-            A list of dicts where the key is the column header.
-        """
-        col_headers = list(intern(str(column[0])) for column in cursor.description)
-        results = list(dict(zip(col_headers, row)) for row in cursor)
-        return results
-
-    def execute(self, desc, decoder, query, *args):
-        """Runs a single query for a result set.
-
-        Args:
-            desc - A description of the query, for logging
-            decoder - The function which can resolve the cursor results to
-                something meaningful.
-            query - The query string to execute
-            *args - Query args.
-        Returns:
-            The result of decoder(results)
-        """
-
-        def interaction(txn):
-            txn.execute(query, args)
-            if decoder:
-                return decoder(txn)
-            else:
-                return txn.fetchall()
-
-        return self.runInteraction(desc, interaction)
-
-    # "Simple" SQL API methods that operate on a single table with no JOINs,
-    # no complex WHERE clauses, just a dict of values for columns.
-
-    @defer.inlineCallbacks
-    def simple_insert(self, table, values, or_ignore=False, desc="simple_insert"):
-        """Executes an INSERT query on the named table.
-
-        Args:
-            table : string giving the table name
-            values : dict of new column names and values for them
-            or_ignore : bool; if True, suppress the exception raised when a
-                conflicting row already exists, and return False from the
-                function instead
-            desc : string giving a description of the transaction
-
-        Returns:
-            bool: Whether the row was inserted or not. Only useful when
-            `or_ignore` is True
-        """
-        try:
-            yield self.runInteraction(desc, self.simple_insert_txn, table, values)
-        except self.database_engine.module.IntegrityError:
-            # We have to do or_ignore flag at this layer, since we can't reuse
-            # a cursor after we receive an error from the db.
-            if not or_ignore:
-                raise
-            return False
-        return True
-
-    @staticmethod
-    def simple_insert_txn(txn, table, values):
-        keys, vals = zip(*values.items())
-
-        sql = "INSERT INTO %s (%s) VALUES(%s)" % (
-            table,
-            ", ".join(k for k in keys),
-            ", ".join("?" for _ in keys),
-        )
-
-        txn.execute(sql, vals)
-
-    def simple_insert_many(self, table, values, desc):
-        return self.runInteraction(desc, self.simple_insert_many_txn, table, values)
-
-    @staticmethod
-    def simple_insert_many_txn(txn, table, values):
-        if not values:
-            return
-
-        # This is a *slight* abomination to get a list of tuples of key names
-        # and a list of tuples of value names.
-        #
-        # i.e. [{"a": 1, "b": 2}, {"c": 3, "d": 4}]
-        #         => [("a", "b",), ("c", "d",)] and [(1, 2,), (3, 4,)]
-        #
-        # The sort is to ensure that we don't rely on dictionary iteration
-        # order.
-        keys, vals = zip(
-            *[zip(*(sorted(i.items(), key=lambda kv: kv[0]))) for i in values if i]
-        )
-
-        for k in keys:
-            if k != keys[0]:
-                raise RuntimeError("All items must have the same keys")
-
-        sql = "INSERT INTO %s (%s) VALUES(%s)" % (
-            table,
-            ", ".join(k for k in keys[0]),
-            ", ".join("?" for _ in keys[0]),
-        )
-
-        txn.executemany(sql, vals)
-
-    @defer.inlineCallbacks
-    def simple_upsert(
-        self,
-        table,
-        keyvalues,
-        values,
-        insertion_values={},
-        desc="simple_upsert",
-        lock=True,
-    ):
-        """
-
-        `lock` should generally be set to True (the default), but can be set
-        to False if either of the following are true:
-
-        * there is a UNIQUE INDEX on the key columns. In this case a conflict
-          will cause an IntegrityError in which case this function will retry
-          the update.
-
-        * we somehow know that we are the only thread which will be updating
-          this table.
-
-        Args:
-            table (str): The table to upsert into
-            keyvalues (dict): The unique key columns and their new values
-            values (dict): The nonunique columns and their new values
-            insertion_values (dict): additional key/values to use only when
-                inserting
-            lock (bool): True to lock the table when doing the upsert.
-        Returns:
-            Deferred(None or bool): Native upserts always return None. Emulated
-            upserts return True if a new entry was created, False if an existing
-            one was updated.
-        """
-        attempts = 0
-        while True:
-            try:
-                result = yield self.runInteraction(
-                    desc,
-                    self.simple_upsert_txn,
-                    table,
-                    keyvalues,
-                    values,
-                    insertion_values,
-                    lock=lock,
-                )
-                return result
-            except self.database_engine.module.IntegrityError as e:
-                attempts += 1
-                if attempts >= 5:
-                    # don't retry forever, because things other than races
-                    # can cause IntegrityErrors
-                    raise
-
-                # presumably we raced with another transaction: let's retry.
-                logger.warning(
-                    "IntegrityError when upserting into %s; retrying: %s", table, e
-                )
-
-    def simple_upsert_txn(
-        self, txn, table, keyvalues, values, insertion_values={}, lock=True
-    ):
-        """
-        Pick the UPSERT method which works best on the platform. Either the
-        native one (Pg9.5+, recent SQLites), or fall back to an emulated method.
-
-        Args:
-            txn: The transaction to use.
-            table (str): The table to upsert into
-            keyvalues (dict): The unique key columns and their new values
-            values (dict): The nonunique columns and their new values
-            insertion_values (dict): additional key/values to use only when
-                inserting
-            lock (bool): True to lock the table when doing the upsert.
-        Returns:
-            None or bool: Native upserts always return None. Emulated
-            upserts return True if a new entry was created, False if an existing
-            one was updated.
-        """
-        if (
-            self.database_engine.can_native_upsert
-            and table not in self._unsafe_to_upsert_tables
-        ):
-            return self.simple_upsert_txn_native_upsert(
-                txn, table, keyvalues, values, insertion_values=insertion_values
-            )
-        else:
-            return self.simple_upsert_txn_emulated(
-                txn,
-                table,
-                keyvalues,
-                values,
-                insertion_values=insertion_values,
-                lock=lock,
-            )
-
-    def simple_upsert_txn_emulated(
-        self, txn, table, keyvalues, values, insertion_values={}, lock=True
-    ):
-        """
-        Args:
-            table (str): The table to upsert into
-            keyvalues (dict): The unique key columns and their new values
-            values (dict): The nonunique columns and their new values
-            insertion_values (dict): additional key/values to use only when
-                inserting
-            lock (bool): True to lock the table when doing the upsert.
-        Returns:
-            bool: Return True if a new entry was created, False if an existing
-            one was updated.
-        """
-        # We need to lock the table :(, unless we're *really* careful
-        if lock:
-            self.database_engine.lock_table(txn, table)
-
-        def _getwhere(key):
-            # If the value we're passing in is None (aka NULL), we need to use
-            # IS, not =, as NULL = NULL evaluates to NULL rather than TRUE.
-            if keyvalues[key] is None:
-                return "%s IS ?" % (key,)
-            else:
-                return "%s = ?" % (key,)
-
-        if not values:
-            # If `values` is empty, then all of the values we care about are in
-            # the unique key, so there is nothing to UPDATE. We can just do a
-            # SELECT instead to see if it exists.
-            sql = "SELECT 1 FROM %s WHERE %s" % (
-                table,
-                " AND ".join(_getwhere(k) for k in keyvalues),
-            )
-            sqlargs = list(keyvalues.values())
-            txn.execute(sql, sqlargs)
-            if txn.fetchall():
-                # We have an existing record.
-                return False
-        else:
-            # First try to update.
-            sql = "UPDATE %s SET %s WHERE %s" % (
-                table,
-                ", ".join("%s = ?" % (k,) for k in values),
-                " AND ".join(_getwhere(k) for k in keyvalues),
-            )
-            sqlargs = list(values.values()) + list(keyvalues.values())
-
-            txn.execute(sql, sqlargs)
-            if txn.rowcount > 0:
-                # successfully updated at least one row.
-                return False
-
-        # We didn't find any existing rows, so insert a new one
-        allvalues = {}
-        allvalues.update(keyvalues)
-        allvalues.update(values)
-        allvalues.update(insertion_values)
-
-        sql = "INSERT INTO %s (%s) VALUES (%s)" % (
-            table,
-            ", ".join(k for k in allvalues),
-            ", ".join("?" for _ in allvalues),
-        )
-        txn.execute(sql, list(allvalues.values()))
-        # successfully inserted
-        return True
-
-    def simple_upsert_txn_native_upsert(
-        self, txn, table, keyvalues, values, insertion_values={}
-    ):
-        """
-        Use the native UPSERT functionality in recent PostgreSQL versions.
-
-        Args:
-            table (str): The table to upsert into
-            keyvalues (dict): The unique key columns and their new values
-            values (dict): The nonunique columns and their new values
-            insertion_values (dict): additional key/values to use only when
-                inserting
-        Returns:
-            None
-        """
-        allvalues = {}
-        allvalues.update(keyvalues)
-        allvalues.update(insertion_values)
-
-        if not values:
-            latter = "NOTHING"
-        else:
-            allvalues.update(values)
-            latter = "UPDATE SET " + ", ".join(k + "=EXCLUDED." + k for k in values)
-
-        sql = ("INSERT INTO %s (%s) VALUES (%s) ON CONFLICT (%s) DO %s") % (
-            table,
-            ", ".join(k for k in allvalues),
-            ", ".join("?" for _ in allvalues),
-            ", ".join(k for k in keyvalues),
-            latter,
-        )
-        txn.execute(sql, list(allvalues.values()))
-
-    def simple_upsert_many_txn(
-        self, txn, table, key_names, key_values, value_names, value_values
-    ):
-        """
-        Upsert, many times.
-
-        Args:
-            table (str): The table to upsert into
-            key_names (list[str]): The key column names.
-            key_values (list[list]): A list of each row's key column values.
-            value_names (list[str]): The value column names. If empty, no
-                values will be used, even if value_values is provided.
-            value_values (list[list]): A list of each row's value column values.
-        Returns:
-            None
-        """
-        if (
-            self.database_engine.can_native_upsert
-            and table not in self._unsafe_to_upsert_tables
-        ):
-            return self.simple_upsert_many_txn_native_upsert(
-                txn, table, key_names, key_values, value_names, value_values
-            )
-        else:
-            return self.simple_upsert_many_txn_emulated(
-                txn, table, key_names, key_values, value_names, value_values
-            )
-
-    def simple_upsert_many_txn_emulated(
-        self, txn, table, key_names, key_values, value_names, value_values
-    ):
-        """
-        Upsert, many times, but without native UPSERT support or batching.
-
-        Args:
-            table (str): The table to upsert into
-            key_names (list[str]): The key column names.
-            key_values (list[list]): A list of each row's key column values.
-            value_names (list[str]): The value column names. If empty, no
-                values will be used, even if value_values is provided.
-            value_values (list[list]): A list of each row's value column values.
-        Returns:
-            None
-        """
-        # No value columns, therefore make a blank list so that the following
-        # zip() works correctly.
-        if not value_names:
-            value_values = [() for x in range(len(key_values))]
-
-        for keyv, valv in zip(key_values, value_values):
-            _keys = {x: y for x, y in zip(key_names, keyv)}
-            _vals = {x: y for x, y in zip(value_names, valv)}
-
-            self.simple_upsert_txn_emulated(txn, table, _keys, _vals)
-
-    def simple_upsert_many_txn_native_upsert(
-        self, txn, table, key_names, key_values, value_names, value_values
-    ):
-        """
-        Upsert, many times, using batching where possible.
-
-        Args:
-            table (str): The table to upsert into
-            key_names (list[str]): The key column names.
-            key_values (list[list]): A list of each row's key column values.
-            value_names (list[str]): The value column names. If empty, no
-                values will be used, even if value_values is provided.
-            value_values (list[list]): A list of each row's value column values.
-        Returns:
-            None
-        """
-        allnames = []
-        allnames.extend(key_names)
-        allnames.extend(value_names)
-
-        if not value_names:
-            # No value columns, therefore make a blank list so that the
-            # following zip() works correctly.
-            latter = "NOTHING"
-            value_values = [() for x in range(len(key_values))]
-        else:
-            latter = "UPDATE SET " + ", ".join(
-                k + "=EXCLUDED." + k for k in value_names
-            )
-
-        sql = "INSERT INTO %s (%s) VALUES (%s) ON CONFLICT (%s) DO %s" % (
-            table,
-            ", ".join(k for k in allnames),
-            ", ".join("?" for _ in allnames),
-            ", ".join(key_names),
-            latter,
-        )
-
-        args = []
-
-        for x, y in zip(key_values, value_values):
-            args.append(tuple(x) + tuple(y))
-
-        return txn.execute_batch(sql, args)
-
-    def simple_select_one(
-        self, table, keyvalues, retcols, allow_none=False, desc="simple_select_one"
-    ):
-        """Executes a SELECT query on the named table, which is expected to
-        return a single row, returning multiple columns from it.
-
-        Args:
-            table : string giving the table name
-            keyvalues : dict of column names and values to select the row with
-            retcols : list of strings giving the names of the columns to return
-
-            allow_none : If true, return None instead of failing if the SELECT
-              statement returns no rows
-        """
-        return self.runInteraction(
-            desc, self.simple_select_one_txn, table, keyvalues, retcols, allow_none
-        )
-
-    def simple_select_one_onecol(
-        self,
-        table,
-        keyvalues,
-        retcol,
-        allow_none=False,
-        desc="simple_select_one_onecol",
-    ):
-        """Executes a SELECT query on the named table, which is expected to
-        return a single row, returning a single column from it.
-
-        Args:
-            table : string giving the table name
-            keyvalues : dict of column names and values to select the row with
-            retcol : string giving the name of the column to return
-        """
-        return self.runInteraction(
-            desc,
-            self.simple_select_one_onecol_txn,
-            table,
-            keyvalues,
-            retcol,
-            allow_none=allow_none,
-        )
-
-    @classmethod
-    def simple_select_one_onecol_txn(
-        cls, txn, table, keyvalues, retcol, allow_none=False
-    ):
-        ret = cls.simple_select_onecol_txn(
-            txn, table=table, keyvalues=keyvalues, retcol=retcol
-        )
-
-        if ret:
-            return ret[0]
-        else:
-            if allow_none:
-                return None
-            else:
-                raise StoreError(404, "No row found")
-
-    @staticmethod
-    def simple_select_onecol_txn(txn, table, keyvalues, retcol):
-        sql = ("SELECT %(retcol)s FROM %(table)s") % {"retcol": retcol, "table": table}
-
-        if keyvalues:
-            sql += " WHERE %s" % " AND ".join("%s = ?" % k for k in iterkeys(keyvalues))
-            txn.execute(sql, list(keyvalues.values()))
-        else:
-            txn.execute(sql)
-
-        return [r[0] for r in txn]
-
-    def simple_select_onecol(
-        self, table, keyvalues, retcol, desc="simple_select_onecol"
-    ):
-        """Executes a SELECT query on the named table, which returns a list
-        comprising the values of the named column from the selected rows.
-
-        Args:
-            table (str): table name
-            keyvalues (dict|None): column names and values to select the rows with
-            retcol (str): column whose value we wish to retrieve.
-
-        Returns:
-            Deferred: Results in a list
-        """
-        return self.runInteraction(
-            desc, self.simple_select_onecol_txn, table, keyvalues, retcol
-        )
-
-    def simple_select_list(self, table, keyvalues, retcols, desc="simple_select_list"):
-        """Executes a SELECT query on the named table, which may return zero or
-        more rows, returning the result as a list of dicts.
-
-        Args:
-            table (str): the table name
-            keyvalues (dict[str, Any] | None):
-                column names and values to select the rows with, or None to not
-                apply a WHERE clause.
-            retcols (iterable[str]): the names of the columns to return
-        Returns:
-            defer.Deferred: resolves to list[dict[str, Any]]
-        """
-        return self.runInteraction(
-            desc, self.simple_select_list_txn, table, keyvalues, retcols
-        )
-
-    @classmethod
-    def simple_select_list_txn(cls, txn, table, keyvalues, retcols):
-        """Executes a SELECT query on the named table, which may return zero or
-        more rows, returning the result as a list of dicts.
-
-        Args:
-            txn : Transaction object
-            table (str): the table name
-            keyvalues (dict[str, T] | None):
-                column names and values to select the rows with, or None to not
-                apply a WHERE clause.
-            retcols (iterable[str]): the names of the columns to return
-        """
-        if keyvalues:
-            sql = "SELECT %s FROM %s WHERE %s" % (
-                ", ".join(retcols),
-                table,
-                " AND ".join("%s = ?" % (k,) for k in keyvalues),
-            )
-            txn.execute(sql, list(keyvalues.values()))
-        else:
-            sql = "SELECT %s FROM %s" % (", ".join(retcols), table)
-            txn.execute(sql)
-
-        return cls.cursor_to_dict(txn)
-
-    @defer.inlineCallbacks
-    def simple_select_many_batch(
-        self,
-        table,
-        column,
-        iterable,
-        retcols,
-        keyvalues={},
-        desc="simple_select_many_batch",
-        batch_size=100,
-    ):
-        """Executes a SELECT query on the named table, which may return zero or
-        more rows, returning the result as a list of dicts.
-
-        Filters rows by whether the value of `column` is in `iterable`.
-
-        Args:
-            table : string giving the table name
-            column : column name to test for inclusion against `iterable`
-            iterable : list
-            keyvalues : dict of column names and values to select the rows with
-            retcols : list of strings giving the names of the columns to return
-        """
-        results = []
-
-        if not iterable:
-            return results
-
-        # iterables can not be sliced, so convert it to a list first
-        it_list = list(iterable)
-
-        chunks = [
-            it_list[i : i + batch_size] for i in range(0, len(it_list), batch_size)
-        ]
-        for chunk in chunks:
-            rows = yield self.runInteraction(
-                desc,
-                self.simple_select_many_txn,
-                table,
-                column,
-                chunk,
-                keyvalues,
-                retcols,
-            )
-
-            results.extend(rows)
-
-        return results
-
-    @classmethod
-    def simple_select_many_txn(cls, txn, table, column, iterable, keyvalues, retcols):
-        """Executes a SELECT query on the named table, which may return zero or
-        more rows, returning the result as a list of dicts.
-
-        Filters rows by whether the value of `column` is in `iterable`.
-
-        Args:
-            txn : Transaction object
-            table : string giving the table name
-            column : column name to test for inclusion against `iterable`
-            iterable : list
-            keyvalues : dict of column names and values to select the rows with
-            retcols : list of strings giving the names of the columns to return
-        """
-        if not iterable:
-            return []
-
-        clause, values = make_in_list_sql_clause(txn.database_engine, column, iterable)
-        clauses = [clause]
-
-        for key, value in iteritems(keyvalues):
-            clauses.append("%s = ?" % (key,))
-            values.append(value)
-
-        sql = "SELECT %s FROM %s WHERE %s" % (
-            ", ".join(retcols),
-            table,
-            " AND ".join(clauses),
-        )
-
-        txn.execute(sql, values)
-        return cls.cursor_to_dict(txn)
-
-    def simple_update(self, table, keyvalues, updatevalues, desc):
-        return self.runInteraction(
-            desc, self.simple_update_txn, table, keyvalues, updatevalues
-        )
-
-    @staticmethod
-    def simple_update_txn(txn, table, keyvalues, updatevalues):
-        if keyvalues:
-            where = "WHERE %s" % " AND ".join("%s = ?" % k for k in iterkeys(keyvalues))
-        else:
-            where = ""
-
-        update_sql = "UPDATE %s SET %s %s" % (
-            table,
-            ", ".join("%s = ?" % (k,) for k in updatevalues),
-            where,
-        )
-
-        txn.execute(update_sql, list(updatevalues.values()) + list(keyvalues.values()))
-
-        return txn.rowcount
-
-    def simple_update_one(
-        self, table, keyvalues, updatevalues, desc="simple_update_one"
-    ):
-        """Executes an UPDATE query on the named table, setting new values for
-        columns in a row matching the key values.
-
-        Args:
-            table : string giving the table name
-            keyvalues : dict of column names and values to select the row with
-            updatevalues : dict giving column names and values to update
-
-        This can be used to implement compare-and-set by putting the update
-        column in the 'keyvalues' dict as well.
-        """
-        return self.runInteraction(
-            desc, self.simple_update_one_txn, table, keyvalues, updatevalues
-        )
-
-    @classmethod
-    def simple_update_one_txn(cls, txn, table, keyvalues, updatevalues):
-        rowcount = cls.simple_update_txn(txn, table, keyvalues, updatevalues)
-
-        if rowcount == 0:
-            raise StoreError(404, "No row found (%s)" % (table,))
-        if rowcount > 1:
-            raise StoreError(500, "More than one row matched (%s)" % (table,))
-
-    @staticmethod
-    def simple_select_one_txn(txn, table, keyvalues, retcols, allow_none=False):
-        select_sql = "SELECT %s FROM %s WHERE %s" % (
-            ", ".join(retcols),
-            table,
-            " AND ".join("%s = ?" % (k,) for k in keyvalues),
-        )
-
-        txn.execute(select_sql, list(keyvalues.values()))
-        row = txn.fetchone()
-
-        if not row:
-            if allow_none:
-                return None
-            raise StoreError(404, "No row found (%s)" % (table,))
-        if txn.rowcount > 1:
-            raise StoreError(500, "More than one row matched (%s)" % (table,))
-
-        return dict(zip(retcols, row))
-
-    def simple_delete_one(self, table, keyvalues, desc="simple_delete_one"):
-        """Executes a DELETE query on the named table, expecting to delete a
-        single row.
-
-        Args:
-            table : string giving the table name
-            keyvalues : dict of column names and values to select the row with
-        """
-        return self.runInteraction(desc, self.simple_delete_one_txn, table, keyvalues)
-
-    @staticmethod
-    def simple_delete_one_txn(txn, table, keyvalues):
-        """Executes a DELETE query on the named table, expecting to delete a
-        single row.
-
-        Args:
-            table : string giving the table name
-            keyvalues : dict of column names and values to select the row with
-        """
-        sql = "DELETE FROM %s WHERE %s" % (
-            table,
-            " AND ".join("%s = ?" % (k,) for k in keyvalues),
-        )
-
-        txn.execute(sql, list(keyvalues.values()))
-        if txn.rowcount == 0:
-            raise StoreError(404, "No row found (%s)" % (table,))
-        if txn.rowcount > 1:
-            raise StoreError(500, "More than one row matched (%s)" % (table,))
-
-    def simple_delete(self, table, keyvalues, desc):
-        return self.runInteraction(desc, self.simple_delete_txn, table, keyvalues)
-
-    @staticmethod
-    def simple_delete_txn(txn, table, keyvalues):
-        sql = "DELETE FROM %s WHERE %s" % (
-            table,
-            " AND ".join("%s = ?" % (k,) for k in keyvalues),
-        )
-
-        txn.execute(sql, list(keyvalues.values()))
-        return txn.rowcount
-
-    def simple_delete_many(self, table, column, iterable, keyvalues, desc):
-        return self.runInteraction(
-            desc, self.simple_delete_many_txn, table, column, iterable, keyvalues
-        )
-
-    @staticmethod
-    def simple_delete_many_txn(txn, table, column, iterable, keyvalues):
-        """Executes a DELETE query on the named table.
-
-        Filters rows by whether the value of `column` is in `iterable`.
-
-        Args:
-            txn : Transaction object
-            table : string giving the table name
-            column : column name to test for inclusion against `iterable`
-            iterable : list
-            keyvalues : dict of column names and values to select the rows with
-
-        Returns:
-            int: Number of rows deleted
-        """
-        if not iterable:
-            return 0
-
-        sql = "DELETE FROM %s" % table
-
-        clause, values = make_in_list_sql_clause(txn.database_engine, column, iterable)
-        clauses = [clause]
-
-        for key, value in iteritems(keyvalues):
-            clauses.append("%s = ?" % (key,))
-            values.append(value)
-
-        if clauses:
-            sql = "%s WHERE %s" % (sql, " AND ".join(clauses))
-        txn.execute(sql, values)
-
-        return txn.rowcount
-
-    def get_cache_dict(
-        self, db_conn, table, entity_column, stream_column, max_value, limit=100000
-    ):
-        # Fetch a mapping of entity -> max stream position for "recent" entities.
-        # It doesn't really matter how many we get, the StreamChangeCache will
-        # do the right thing to ensure it respects the max size of cache.
-        sql = (
-            "SELECT %(entity)s, MAX(%(stream)s) FROM %(table)s"
-            " WHERE %(stream)s > ? - %(limit)s"
-            " GROUP BY %(entity)s"
-        ) % {
-            "table": table,
-            "entity": entity_column,
-            "stream": stream_column,
-            "limit": limit,
-        }
-
-        sql = self.database_engine.convert_param_style(sql)
-
-        txn = db_conn.cursor()
-        txn.execute(sql, (int(max_value),))
-
-        cache = {row[0]: int(row[1]) for row in txn}
-
-        txn.close()
-
-        if cache:
-            min_val = min(itervalues(cache))
-        else:
-            min_val = max_value
-
-        return cache, min_val
-
     def _invalidate_state_caches(self, room_id, members_changed):
         """Invalidates caches that are based on the current state, but does
         not stream invalidations down replication.
@@ -1347,165 +77,6 @@ class SQLBaseStore(object):
             # which is fine.
             pass
 
-    def simple_select_list_paginate(
-        self,
-        table,
-        orderby,
-        start,
-        limit,
-        retcols,
-        filters=None,
-        keyvalues=None,
-        order_direction="ASC",
-        desc="simple_select_list_paginate",
-    ):
-        """
-        Executes a SELECT query on the named table, bounded by a start index
-        and a row limit, which may return anywhere from zero up to `limit`
-        rows, returning the result as a list of dicts.
-
-        Args:
-            table (str): the table name
-            filters (dict[str, T] | None):
-                column names and values to filter the rows with, or None to not
-                apply a WHERE ? LIKE ? clause.
-            keyvalues (dict[str, T] | None):
-                column names and values to select the rows with, or None to not
-                apply a WHERE clause.
-            orderby (str): Column to order the results by.
-            start (int): Index to begin the query at.
-            limit (int): Number of results to return.
-            retcols (iterable[str]): the names of the columns to return
-            order_direction (str): Whether the results should be ordered "ASC" or "DESC".
-        Returns:
-            defer.Deferred: resolves to list[dict[str, Any]]
-        """
-        return self.runInteraction(
-            desc,
-            self.simple_select_list_paginate_txn,
-            table,
-            orderby,
-            start,
-            limit,
-            retcols,
-            filters=filters,
-            keyvalues=keyvalues,
-            order_direction=order_direction,
-        )
-
-    @classmethod
-    def simple_select_list_paginate_txn(
-        cls,
-        txn,
-        table,
-        orderby,
-        start,
-        limit,
-        retcols,
-        filters=None,
-        keyvalues=None,
-        order_direction="ASC",
-    ):
-        """
-        Executes a SELECT query on the named table, bounded by a start index
-        and a row limit, which may return anywhere from zero up to `limit`
-        rows, returning the result as a list of dicts.
-
-        Use `filters` to search attributes using SQL wildcards and/or `keyvalues` to
-        select attributes with exact matches. All constraints are joined together
-        using 'AND'.
-
-        Args:
-            txn : Transaction object
-            table (str): the table name
-            orderby (str): Column to order the results by.
-            start (int): Index to begin the query at.
-            limit (int): Number of results to return.
-            retcols (iterable[str]): the names of the columns to return
-            filters (dict[str, T] | None):
-                column names and values to filter the rows with, or None to not
-                apply a WHERE ? LIKE ? clause.
-            keyvalues (dict[str, T] | None):
-                column names and values to select the rows with, or None to not
-                apply a WHERE clause.
-            order_direction (str): Whether the results should be ordered "ASC" or "DESC".
-        Returns:
-            defer.Deferred: resolves to list[dict[str, Any]]
-        """
-        if order_direction not in ["ASC", "DESC"]:
-            raise ValueError("order_direction must be one of 'ASC' or 'DESC'.")
-
-        where_clause = "WHERE " if filters or keyvalues else ""
-        arg_list = []
-        if filters:
-            where_clause += " AND ".join("%s LIKE ?" % (k,) for k in filters)
-            arg_list += list(filters.values())
-        where_clause += " AND " if filters and keyvalues else ""
-        if keyvalues:
-            where_clause += " AND ".join("%s = ?" % (k,) for k in keyvalues)
-            arg_list += list(keyvalues.values())
-
-        sql = "SELECT %s FROM %s %s ORDER BY %s %s LIMIT ? OFFSET ?" % (
-            ", ".join(retcols),
-            table,
-            where_clause,
-            orderby,
-            order_direction,
-        )
-        txn.execute(sql, arg_list + [limit, start])
-
-        return cls.cursor_to_dict(txn)
-
-    def simple_search_list(self, table, term, col, retcols, desc="simple_search_list"):
-        """Executes a SELECT query on the named table, which may return zero or
-        more rows, returning the result as a list of dicts.
-
-        Args:
-            table (str): the table name
-            term (str | None):
-                the term to search for, or None to skip the search.
-            col (str): the column against which the term should be matched
-            retcols (iterable[str]): the names of the columns to return
-        Returns:
-            defer.Deferred: resolves to list[dict[str, Any]] or None
-        """
-
-        return self.runInteraction(
-            desc, self.simple_search_list_txn, table, term, col, retcols
-        )
-
-    @classmethod
-    def simple_search_list_txn(cls, txn, table, term, col, retcols):
-        """Executes a SELECT query on the named table, which may return zero or
-        more rows, returning the result as a list of dicts.
-
-        Args:
-            txn : Transaction object
-            table (str): the table name
-            term (str | None):
-                the term to search for, or None to skip the search.
-            col (str): the column against which the term should be matched
-            retcols (iterable[str]): the names of the columns to return
-        Returns:
-            defer.Deferred: resolves to list[dict[str, Any]] or None
-        """
-        if term:
-            sql = "SELECT %s FROM %s WHERE %s LIKE ?" % (", ".join(retcols), table, col)
-            termvalues = ["%%" + term + "%%"]
-            txn.execute(sql, termvalues)
-        else:
-            return 0
-
-        return cls.cursor_to_dict(txn)
-
-
-class _RollbackButIsFineException(Exception):
-    """ This exception is used to rollback a transaction without implying
-    something went wrong.
-    """
-
-    pass
-
 
 def db_to_json(db_content):
     """
@@ -1534,30 +105,3 @@ def db_to_json(db_content):
     except Exception:
         logging.warning("Tried to decode '%r' as JSON and failed", db_content)
         raise
-
-
-def make_in_list_sql_clause(
-    database_engine, column: str, iterable: Iterable
-) -> Tuple[str, Iterable]:
-    """Returns an SQL clause that checks the given column is in the iterable.
-
-    On SQLite this expands to `column IN (?, ?, ...)`, whereas on Postgres
-    it expands to `column = ANY(?)`. While both DBs support the `IN` form,
-    using the `ANY` form on postgres means that it views queries with
-    different length iterables as the same, helping the query stats.
-
-    Args:
-        database_engine
-        column: Name of the column
-        iterable: The values to check the column against.
-
-    Returns:
-        A tuple of SQL query and the args
-    """
-
-    if database_engine.supports_using_any_list:
-        # This should hopefully be faster, but also makes postgres query
-        # stats easier to understand.
-        return "%s = ANY(?)" % (column,), [list(iterable)]
-    else:
-        return "%s IN (%s)" % (column, ",".join("?" for _ in iterable)), list(iterable)
diff --git a/synapse/storage/background_updates.py b/synapse/storage/background_updates.py
index 06955a0537..4f97fd5ab6 100644
--- a/synapse/storage/background_updates.py
+++ b/synapse/storage/background_updates.py
@@ -22,7 +22,6 @@ from twisted.internet import defer
 from synapse.metrics.background_process_metrics import run_as_background_process
 
 from . import engines
-from ._base import SQLBaseStore
 
 logger = logging.getLogger(__name__)
 
@@ -74,7 +73,7 @@ class BackgroundUpdatePerformance(object):
             return float(self.total_item_count) / float(self.total_duration_ms)
 
 
-class BackgroundUpdateStore(SQLBaseStore):
+class BackgroundUpdater(object):
     """ Background updates are updates to the database that run in the
     background. Each update processes a batch of data at once. We attempt to
     limit the impact of each update by monitoring how long each batch takes to
@@ -86,8 +85,10 @@ class BackgroundUpdateStore(SQLBaseStore):
     BACKGROUND_UPDATE_INTERVAL_MS = 1000
     BACKGROUND_UPDATE_DURATION_MS = 100
 
-    def __init__(self, db_conn, hs):
-        super(BackgroundUpdateStore, self).__init__(db_conn, hs)
+    def __init__(self, hs, database):
+        self._clock = hs.get_clock()
+        self.db = database
+
         self._background_update_performance = {}
         self._background_update_queue = []
         self._background_update_handlers = {}
@@ -101,9 +102,7 @@ class BackgroundUpdateStore(SQLBaseStore):
         logger.info("Starting background schema updates")
         while True:
             if sleep:
-                yield self.hs.get_clock().sleep(
-                    self.BACKGROUND_UPDATE_INTERVAL_MS / 1000.0
-                )
+                yield self._clock.sleep(self.BACKGROUND_UPDATE_INTERVAL_MS / 1000.0)
 
             try:
                 result = yield self.do_next_background_update(
@@ -139,7 +138,7 @@ class BackgroundUpdateStore(SQLBaseStore):
         # otherwise, check if there are updates to be run. This is important,
         # as we may be running on a worker which doesn't perform the bg updates
         # itself, but still wants to wait for them to happen.
-        updates = yield self.simple_select_onecol(
+        updates = yield self.db.simple_select_onecol(
             "background_updates",
             keyvalues=None,
             retcol="1",
@@ -161,7 +160,7 @@ class BackgroundUpdateStore(SQLBaseStore):
         if update_name in self._background_update_queue:
             return False
 
-        update_exists = await self.simple_select_one_onecol(
+        update_exists = await self.db.simple_select_one_onecol(
             "background_updates",
             keyvalues={"update_name": update_name},
             retcol="1",
@@ -184,7 +183,7 @@ class BackgroundUpdateStore(SQLBaseStore):
             no more work to do.
         """
         if not self._background_update_queue:
-            updates = yield self.simple_select_list(
+            updates = yield self.db.simple_select_list(
                 "background_updates",
                 keyvalues=None,
                 retcols=("update_name", "depends_on"),
@@ -226,7 +225,7 @@ class BackgroundUpdateStore(SQLBaseStore):
         else:
             batch_size = self.DEFAULT_BACKGROUND_BATCH_SIZE
 
-        progress_json = yield self.simple_select_one_onecol(
+        progress_json = yield self.db.simple_select_one_onecol(
             "background_updates",
             keyvalues={"update_name": update_name},
             retcol="progress_json",
@@ -380,7 +379,7 @@ class BackgroundUpdateStore(SQLBaseStore):
             logger.debug("[SQL] %s", sql)
             c.execute(sql)
 
-        if isinstance(self.database_engine, engines.PostgresEngine):
+        if isinstance(self.db.engine, engines.PostgresEngine):
             runner = create_index_psql
         elif psql_only:
             runner = None
@@ -391,7 +390,7 @@ class BackgroundUpdateStore(SQLBaseStore):
         def updater(progress, batch_size):
             if runner is not None:
                 logger.info("Adding index %s to %s", index_name, table)
-                yield self.runWithConnection(runner)
+                yield self.db.runWithConnection(runner)
             yield self._end_background_update(update_name)
             return 1
 
@@ -413,7 +412,7 @@ class BackgroundUpdateStore(SQLBaseStore):
         self._background_update_queue = []
         progress_json = json.dumps(progress)
 
-        return self.simple_insert(
+        return self.db.simple_insert(
             "background_updates",
             {"update_name": update_name, "progress_json": progress_json},
         )
@@ -429,7 +428,7 @@ class BackgroundUpdateStore(SQLBaseStore):
         self._background_update_queue = [
             name for name in self._background_update_queue if name != update_name
         ]
-        return self.simple_delete_one(
+        return self.db.simple_delete_one(
             "background_updates", keyvalues={"update_name": update_name}
         )
 
@@ -444,7 +443,7 @@ class BackgroundUpdateStore(SQLBaseStore):
 
         progress_json = json.dumps(progress)
 
-        self.simple_update_one_txn(
+        self.db.simple_update_one_txn(
             txn,
             "background_updates",
             keyvalues={"update_name": update_name},
diff --git a/synapse/storage/data_stores/__init__.py b/synapse/storage/data_stores/__init__.py
index cb184a98cc..cafedd5c0d 100644
--- a/synapse/storage/data_stores/__init__.py
+++ b/synapse/storage/data_stores/__init__.py
@@ -13,6 +13,9 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
+from synapse.storage.database import Database
+from synapse.storage.prepare_database import prepare_database
+
 
 class DataStores(object):
     """The various data stores.
@@ -20,7 +23,14 @@ class DataStores(object):
     These are low level interfaces to physical databases.
     """
 
-    def __init__(self, main_store, db_conn, hs):
-        # Note we pass in the main store here as workers use a different main
+    def __init__(self, main_store_class, db_conn, hs):
+        # Note we pass in the main store class here as workers use a different main
         # store.
-        self.main = main_store
+        database = Database(hs)
+
+        # Check that db is correctly configured.
+        database.engine.check_database(db_conn.cursor())
+
+        prepare_database(db_conn, database.engine, config=hs.config)
+
+        self.main = main_store_class(database, db_conn, hs)
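
Passing the main store *class* (rather than an instance) lets `DataStores` construct the shared `Database`, check and prepare the physical database, and only then build the store against it. Roughly, a caller now does something like:

    # hedged sketch of the new call site; the real caller lives elsewhere
    # in this commit and may differ in detail
    stores = DataStores(DataStore, db_conn, hs)  # pass the class, not an instance
    main_store = stores.main  # a DataStore wired to the shared Database(hs)
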
diff --git a/synapse/storage/data_stores/main/__init__.py b/synapse/storage/data_stores/main/__init__.py
index 3720ff3088..c577c0df5f 100644
--- a/synapse/storage/data_stores/main/__init__.py
+++ b/synapse/storage/data_stores/main/__init__.py
@@ -20,6 +20,7 @@ import logging
 import time
 
 from synapse.api.constants import PresenceState
+from synapse.storage.database import Database
 from synapse.storage.engines import PostgresEngine
 from synapse.storage.util.id_generators import (
     ChainedIdGenerator,
@@ -111,10 +112,20 @@ class DataStore(
     RelationsStore,
     CacheInvalidationStore,
 ):
-    def __init__(self, db_conn, hs):
+    def __init__(self, database: Database, db_conn, hs):
         self.hs = hs
         self._clock = hs.get_clock()
-        self.database_engine = hs.database_engine
+        self.database_engine = database.engine
+
+        all_users_native = are_all_users_on_domain(
+            db_conn.cursor(), database.engine, hs.hostname
+        )
+        if not all_users_native:
+            raise Exception(
+                "Found users in database not native to %s!\n"
+                "You cannot changed a synapse server_name after it's been configured"
+                % (hs.hostname,)
+            )
 
         self._stream_id_gen = StreamIdGenerator(
             db_conn,
@@ -169,9 +180,11 @@ class DataStore(
         else:
             self._cache_id_gen = None
 
+        super(DataStore, self).__init__(database, db_conn, hs)
+
         self._presence_on_startup = self._get_active_presence(db_conn)
 
-        presence_cache_prefill, min_presence_val = self.get_cache_dict(
+        presence_cache_prefill, min_presence_val = self.db.get_cache_dict(
             db_conn,
             "presence_stream",
             entity_column="user_id",
@@ -185,7 +198,7 @@ class DataStore(
         )
 
         max_device_inbox_id = self._device_inbox_id_gen.get_current_token()
-        device_inbox_prefill, min_device_inbox_id = self.get_cache_dict(
+        device_inbox_prefill, min_device_inbox_id = self.db.get_cache_dict(
             db_conn,
             "device_inbox",
             entity_column="user_id",
@@ -200,7 +213,7 @@ class DataStore(
         )
         # The federation outbox and the local device inbox uses the same
         # stream_id generator.
-        device_outbox_prefill, min_device_outbox_id = self.get_cache_dict(
+        device_outbox_prefill, min_device_outbox_id = self.db.get_cache_dict(
             db_conn,
             "device_federation_outbox",
             entity_column="destination",
@@ -226,7 +239,7 @@ class DataStore(
         )
 
         events_max = self._stream_id_gen.get_current_token()
-        curr_state_delta_prefill, min_curr_state_delta_id = self.get_cache_dict(
+        curr_state_delta_prefill, min_curr_state_delta_id = self.db.get_cache_dict(
             db_conn,
             "current_state_delta_stream",
             entity_column="room_id",
@@ -240,7 +253,7 @@ class DataStore(
             prefilled_cache=curr_state_delta_prefill,
         )
 
-        _group_updates_prefill, min_group_updates_id = self.get_cache_dict(
+        _group_updates_prefill, min_group_updates_id = self.db.get_cache_dict(
             db_conn,
             "local_group_updates",
             entity_column="user_id",
@@ -260,8 +273,6 @@ class DataStore(
         # Used in _generate_user_daily_visits to keep track of progress
         self._last_user_visit_update = self._get_start_of_day()
 
-        super(DataStore, self).__init__(db_conn, hs)
-
     def take_presence_startup_info(self):
         active_on_startup = self._presence_on_startup
         self._presence_on_startup = None
@@ -281,7 +292,7 @@ class DataStore(
 
         txn = db_conn.cursor()
         txn.execute(sql, (PresenceState.OFFLINE,))
-        rows = self.cursor_to_dict(txn)
+        rows = self.db.cursor_to_dict(txn)
         txn.close()
 
         for row in rows:
@@ -294,7 +305,7 @@ class DataStore(
         Counts the number of users who used this homeserver in the last 24 hours.
         """
         yesterday = int(self._clock.time_msec()) - (1000 * 60 * 60 * 24)
-        return self.runInteraction("count_daily_users", self._count_users, yesterday)
+        return self.db.runInteraction("count_daily_users", self._count_users, yesterday)
 
     def count_monthly_users(self):
         """
@@ -304,7 +315,7 @@ class DataStore(
         amongst other things, includes a 3 day grace period before a user counts.
         """
         thirty_days_ago = int(self._clock.time_msec()) - (1000 * 60 * 60 * 24 * 30)
-        return self.runInteraction(
+        return self.db.runInteraction(
             "count_monthly_users", self._count_users, thirty_days_ago
         )
 
@@ -404,7 +415,7 @@ class DataStore(
 
             return results
 
-        return self.runInteraction("count_r30_users", _count_r30_users)
+        return self.db.runInteraction("count_r30_users", _count_r30_users)
 
     def _get_start_of_day(self):
         """
@@ -469,7 +480,7 @@ class DataStore(
             # frequently
             self._last_user_visit_update = now
 
-        return self.runInteraction(
+        return self.db.runInteraction(
             "generate_user_daily_visits", _generate_user_daily_visits
         )
 
@@ -480,7 +491,7 @@ class DataStore(
         Returns:
             defer.Deferred: resolves to list[dict[str, Any]]
         """
-        return self.simple_select_list(
+        return self.db.simple_select_list(
             table="users",
             keyvalues={},
             retcols=[
@@ -519,7 +530,7 @@ class DataStore(
         if not deactivated:
             attr_filter["deactivated"] = False
 
-        return self.simple_select_list_paginate(
+        return self.db.simple_select_list_paginate(
             desc="get_users_paginate",
             table="users",
             orderby="name",
@@ -547,10 +558,22 @@ class DataStore(
         Returns:
             defer.Deferred: resolves to list[dict[str, Any]]
         """
-        return self.simple_search_list(
+        return self.db.simple_search_list(
             table="users",
             term=term,
             col="name",
             retcols=["name", "password_hash", "is_guest", "admin", "user_type"],
             desc="search_users",
         )
+
+
+def are_all_users_on_domain(txn, database_engine, domain):
+    sql = database_engine.convert_param_style(
+        "SELECT COUNT(*) FROM users WHERE name NOT LIKE ?"
+    )
+    pat = "%:" + domain
+    txn.execute(sql, (pat,))
+    num_not_matching = txn.fetchall()[0][0]
+    return num_not_matching == 0
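
`are_all_users_on_domain` works because Matrix user IDs take the form `@localpart:server_name`, so every native row in `users` matches the pattern `%:server_name`. An illustrative call (the hostname and user IDs are made up):

    # illustrative only: hostname and rows are hypothetical
    cur = db_conn.cursor()
    ok = are_all_users_on_domain(cur, database.engine, "example.com")
    # "@alice:example.com" matches "%:example.com" -> counted as native
    # "@bob:other.org" does not match              -> ok is False, startup aborts
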
diff --git a/synapse/storage/data_stores/main/account_data.py b/synapse/storage/data_stores/main/account_data.py
index b0d22faf3f..46b494b334 100644
--- a/synapse/storage/data_stores/main/account_data.py
+++ b/synapse/storage/data_stores/main/account_data.py
@@ -22,6 +22,7 @@ from canonicaljson import json
 from twisted.internet import defer
 
 from synapse.storage._base import SQLBaseStore
+from synapse.storage.database import Database
 from synapse.storage.util.id_generators import StreamIdGenerator
 from synapse.util.caches.descriptors import cached, cachedInlineCallbacks
 from synapse.util.caches.stream_change_cache import StreamChangeCache
@@ -38,13 +39,13 @@ class AccountDataWorkerStore(SQLBaseStore):
     # the abstract methods being implemented.
     __metaclass__ = abc.ABCMeta
 
-    def __init__(self, db_conn, hs):
+    def __init__(self, database: Database, db_conn, hs):
         account_max = self.get_max_account_data_stream_id()
         self._account_data_stream_cache = StreamChangeCache(
             "AccountDataAndTagsChangeCache", account_max
         )
 
-        super(AccountDataWorkerStore, self).__init__(db_conn, hs)
+        super(AccountDataWorkerStore, self).__init__(database, db_conn, hs)
 
     @abc.abstractmethod
     def get_max_account_data_stream_id(self):
@@ -67,7 +68,7 @@ class AccountDataWorkerStore(SQLBaseStore):
         """
 
         def get_account_data_for_user_txn(txn):
-            rows = self.simple_select_list_txn(
+            rows = self.db.simple_select_list_txn(
                 txn,
                 "account_data",
                 {"user_id": user_id},
@@ -78,7 +79,7 @@ class AccountDataWorkerStore(SQLBaseStore):
                 row["account_data_type"]: json.loads(row["content"]) for row in rows
             }
 
-            rows = self.simple_select_list_txn(
+            rows = self.db.simple_select_list_txn(
                 txn,
                 "room_account_data",
                 {"user_id": user_id},
@@ -92,7 +93,7 @@ class AccountDataWorkerStore(SQLBaseStore):
 
             return global_account_data, by_room
 
-        return self.runInteraction(
+        return self.db.runInteraction(
             "get_account_data_for_user", get_account_data_for_user_txn
         )
 
@@ -102,7 +103,7 @@ class AccountDataWorkerStore(SQLBaseStore):
         Returns:
             Deferred: A dict
         """
-        result = yield self.simple_select_one_onecol(
+        result = yield self.db.simple_select_one_onecol(
             table="account_data",
             keyvalues={"user_id": user_id, "account_data_type": data_type},
             retcol="content",
@@ -127,7 +128,7 @@ class AccountDataWorkerStore(SQLBaseStore):
         """
 
         def get_account_data_for_room_txn(txn):
-            rows = self.simple_select_list_txn(
+            rows = self.db.simple_select_list_txn(
                 txn,
                 "room_account_data",
                 {"user_id": user_id, "room_id": room_id},
@@ -138,7 +139,7 @@ class AccountDataWorkerStore(SQLBaseStore):
                 row["account_data_type"]: json.loads(row["content"]) for row in rows
             }
 
-        return self.runInteraction(
+        return self.db.runInteraction(
             "get_account_data_for_room", get_account_data_for_room_txn
         )
 
@@ -156,7 +157,7 @@ class AccountDataWorkerStore(SQLBaseStore):
         """
 
         def get_account_data_for_room_and_type_txn(txn):
-            content_json = self.simple_select_one_onecol_txn(
+            content_json = self.db.simple_select_one_onecol_txn(
                 txn,
                 table="room_account_data",
                 keyvalues={
@@ -170,7 +171,7 @@ class AccountDataWorkerStore(SQLBaseStore):
 
             return json.loads(content_json) if content_json else None
 
-        return self.runInteraction(
+        return self.db.runInteraction(
             "get_account_data_for_room_and_type", get_account_data_for_room_and_type_txn
         )
 
@@ -207,7 +208,7 @@ class AccountDataWorkerStore(SQLBaseStore):
             room_results = txn.fetchall()
             return global_results, room_results
 
-        return self.runInteraction(
+        return self.db.runInteraction(
             "get_all_updated_account_data_txn", get_updated_account_data_txn
         )
 
@@ -250,9 +251,9 @@ class AccountDataWorkerStore(SQLBaseStore):
             user_id, int(stream_id)
         )
         if not changed:
-            return {}, {}
+            return defer.succeed(({}, {}))
 
-        return self.runInteraction(
+        return self.db.runInteraction(
             "get_updated_account_data_for_user", get_updated_account_data_for_user_txn
         )
 
@@ -270,12 +271,12 @@ class AccountDataWorkerStore(SQLBaseStore):
 
 
 class AccountDataStore(AccountDataWorkerStore):
-    def __init__(self, db_conn, hs):
+    def __init__(self, database: Database, db_conn, hs):
         self._account_data_id_gen = StreamIdGenerator(
             db_conn, "account_data_max_stream_id", "stream_id"
         )
 
-        super(AccountDataStore, self).__init__(db_conn, hs)
+        super(AccountDataStore, self).__init__(database, db_conn, hs)
 
     def get_max_account_data_stream_id(self):
         """Get the current max stream id for the private user data stream
@@ -302,7 +303,7 @@ class AccountDataStore(AccountDataWorkerStore):
             # no need to lock here as room_account_data has a unique constraint
             # on (user_id, room_id, account_data_type) so simple_upsert will
             # retry if there is a conflict.
-            yield self.simple_upsert(
+            yield self.db.simple_upsert(
                 desc="add_room_account_data",
                 table="room_account_data",
                 keyvalues={
@@ -348,7 +349,7 @@ class AccountDataStore(AccountDataWorkerStore):
             # no need to lock here as account_data has a unique constraint on
             # (user_id, account_data_type) so simple_upsert will retry if
             # there is a conflict.
-            yield self.simple_upsert(
+            yield self.db.simple_upsert(
                 desc="add_user_account_data",
                 table="account_data",
                 keyvalues={"user_id": user_id, "account_data_type": account_data_type},
@@ -388,4 +389,4 @@ class AccountDataStore(AccountDataWorkerStore):
             )
             txn.execute(update_max_id_sql, (next_id, next_id))
 
-        return self.runInteraction("update_account_data_max_stream_id", _update)
+        return self.db.runInteraction("update_account_data_max_stream_id", _update)
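
One behavioural fix rides along with the renames in this file: the early return in `get_updated_account_data_for_user` previously handed back a bare `({}, {})`, while the other branch returns a `Deferred` from `runInteraction`. Wrapping it in `defer.succeed` keeps both paths returning a `Deferred`. Reduced to a sketch (the helper name is assumed):

    from twisted.internet import defer

    def get_updated_account_data(changed):
        if not changed:
            # must be a Deferred, to match the runInteraction branch
            return defer.succeed(({}, {}))
        return run_db_query()  # stand-in for self.db.runInteraction(...)
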
diff --git a/synapse/storage/data_stores/main/appservice.py b/synapse/storage/data_stores/main/appservice.py
index 6b82fd392a..b2f39649fd 100644
--- a/synapse/storage/data_stores/main/appservice.py
+++ b/synapse/storage/data_stores/main/appservice.py
@@ -24,6 +24,7 @@ from synapse.appservice import AppServiceTransaction
 from synapse.config.appservice import load_appservices
 from synapse.storage._base import SQLBaseStore
 from synapse.storage.data_stores.main.events_worker import EventsWorkerStore
+from synapse.storage.database import Database
 
 logger = logging.getLogger(__name__)
 
@@ -48,13 +49,13 @@ def _make_exclusive_regex(services_cache):
 
 
 class ApplicationServiceWorkerStore(SQLBaseStore):
-    def __init__(self, db_conn, hs):
+    def __init__(self, database: Database, db_conn, hs):
         self.services_cache = load_appservices(
             hs.hostname, hs.config.app_service_config_files
         )
         self.exclusive_user_regex = _make_exclusive_regex(self.services_cache)
 
-        super(ApplicationServiceWorkerStore, self).__init__(db_conn, hs)
+        super(ApplicationServiceWorkerStore, self).__init__(database, db_conn, hs)
 
     def get_app_services(self):
         return self.services_cache
@@ -133,7 +134,7 @@ class ApplicationServiceTransactionWorkerStore(
             A Deferred which resolves to a list of ApplicationServices, which
             may be empty.
         """
-        results = yield self.simple_select_list(
+        results = yield self.db.simple_select_list(
             "application_services_state", dict(state=state), ["as_id"]
         )
         # NB: This assumes this class is linked with ApplicationServiceStore
@@ -155,7 +156,7 @@ class ApplicationServiceTransactionWorkerStore(
         Returns:
             A Deferred which resolves to ApplicationServiceState.
         """
-        result = yield self.simple_select_one(
+        result = yield self.db.simple_select_one(
             "application_services_state",
             dict(as_id=service.id),
             ["state"],
@@ -175,7 +176,7 @@ class ApplicationServiceTransactionWorkerStore(
         Returns:
             A Deferred which resolves when the state was set successfully.
         """
-        return self.simple_upsert(
+        return self.db.simple_upsert(
             "application_services_state", dict(as_id=service.id), dict(state=state)
         )
 
@@ -216,7 +217,7 @@ class ApplicationServiceTransactionWorkerStore(
             )
             return AppServiceTransaction(service=service, id=new_txn_id, events=events)
 
-        return self.runInteraction("create_appservice_txn", _create_appservice_txn)
+        return self.db.runInteraction("create_appservice_txn", _create_appservice_txn)
 
     def complete_appservice_txn(self, txn_id, service):
         """Completes an application service transaction.
@@ -249,7 +250,7 @@ class ApplicationServiceTransactionWorkerStore(
                 )
 
             # Set current txn_id for AS to 'txn_id'
-            self.simple_upsert_txn(
+            self.db.simple_upsert_txn(
                 txn,
                 "application_services_state",
                 dict(as_id=service.id),
@@ -257,11 +258,13 @@ class ApplicationServiceTransactionWorkerStore(
             )
 
             # Delete txn
-            self.simple_delete_txn(
+            self.db.simple_delete_txn(
                 txn, "application_services_txns", dict(txn_id=txn_id, as_id=service.id)
             )
 
-        return self.runInteraction("complete_appservice_txn", _complete_appservice_txn)
+        return self.db.runInteraction(
+            "complete_appservice_txn", _complete_appservice_txn
+        )
 
     @defer.inlineCallbacks
     def get_oldest_unsent_txn(self, service):
@@ -283,7 +286,7 @@ class ApplicationServiceTransactionWorkerStore(
                 " ORDER BY txn_id ASC LIMIT 1",
                 (service.id,),
             )
-            rows = self.cursor_to_dict(txn)
+            rows = self.db.cursor_to_dict(txn)
             if not rows:
                 return None
 
@@ -291,7 +294,7 @@ class ApplicationServiceTransactionWorkerStore(
 
             return entry
 
-        entry = yield self.runInteraction(
+        entry = yield self.db.runInteraction(
             "get_oldest_unsent_appservice_txn", _get_oldest_unsent_txn
         )
 
@@ -321,7 +324,7 @@ class ApplicationServiceTransactionWorkerStore(
                 "UPDATE appservice_stream_position SET stream_ordering = ?", (pos,)
             )
 
-        return self.runInteraction(
+        return self.db.runInteraction(
             "set_appservice_last_pos", set_appservice_last_pos_txn
         )
 
@@ -350,7 +353,7 @@ class ApplicationServiceTransactionWorkerStore(
 
             return upper_bound, [row[1] for row in rows]
 
-        upper_bound, event_ids = yield self.runInteraction(
+        upper_bound, event_ids = yield self.db.runInteraction(
             "get_new_events_for_appservice", get_new_events_for_appservice_txn
         )
 
diff --git a/synapse/storage/data_stores/main/cache.py b/synapse/storage/data_stores/main/cache.py
index de3256049d..54ed8574c4 100644
--- a/synapse/storage/data_stores/main/cache.py
+++ b/synapse/storage/data_stores/main/cache.py
@@ -95,7 +95,7 @@ class CacheInvalidationStore(SQLBaseStore):
             txn.call_after(ctx.__exit__, None, None, None)
             txn.call_after(self.hs.get_notifier().on_new_replication_data)
 
-            self.simple_insert_txn(
+            self.db.simple_insert_txn(
                 txn,
                 table="cache_invalidation_stream",
                 values={
@@ -122,7 +122,9 @@ class CacheInvalidationStore(SQLBaseStore):
             txn.execute(sql, (last_id, limit))
             return txn.fetchall()
 
-        return self.runInteraction("get_all_updated_caches", get_all_updated_caches_txn)
+        return self.db.runInteraction(
+            "get_all_updated_caches", get_all_updated_caches_txn
+        )
 
     def get_cache_stream_token(self):
         if self._cache_id_gen:
diff --git a/synapse/storage/data_stores/main/client_ips.py b/synapse/storage/data_stores/main/client_ips.py
index 66522a04b7..add3037b69 100644
--- a/synapse/storage/data_stores/main/client_ips.py
+++ b/synapse/storage/data_stores/main/client_ips.py
@@ -20,7 +20,8 @@ from six import iteritems
 from twisted.internet import defer
 
 from synapse.metrics.background_process_metrics import wrap_as_background_process
-from synapse.storage import background_updates
+from synapse.storage._base import SQLBaseStore
+from synapse.storage.database import Database
 from synapse.util.caches import CACHE_SIZE_FACTOR
 from synapse.util.caches.descriptors import Cache
 
@@ -32,41 +33,41 @@ logger = logging.getLogger(__name__)
 LAST_SEEN_GRANULARITY = 120 * 1000
 
 
-class ClientIpBackgroundUpdateStore(background_updates.BackgroundUpdateStore):
-    def __init__(self, db_conn, hs):
-        super(ClientIpBackgroundUpdateStore, self).__init__(db_conn, hs)
+class ClientIpBackgroundUpdateStore(SQLBaseStore):
+    def __init__(self, database: Database, db_conn, hs):
+        super(ClientIpBackgroundUpdateStore, self).__init__(database, db_conn, hs)
 
-        self.register_background_index_update(
+        self.db.updates.register_background_index_update(
             "user_ips_device_index",
             index_name="user_ips_device_id",
             table="user_ips",
             columns=["user_id", "device_id", "last_seen"],
         )
 
-        self.register_background_index_update(
+        self.db.updates.register_background_index_update(
             "user_ips_last_seen_index",
             index_name="user_ips_last_seen",
             table="user_ips",
             columns=["user_id", "last_seen"],
         )
 
-        self.register_background_index_update(
+        self.db.updates.register_background_index_update(
             "user_ips_last_seen_only_index",
             index_name="user_ips_last_seen_only",
             table="user_ips",
             columns=["last_seen"],
         )
 
-        self.register_background_update_handler(
+        self.db.updates.register_background_update_handler(
             "user_ips_analyze", self._analyze_user_ip
         )
 
-        self.register_background_update_handler(
+        self.db.updates.register_background_update_handler(
             "user_ips_remove_dupes", self._remove_user_ip_dupes
         )
 
         # Register a unique index
-        self.register_background_index_update(
+        self.db.updates.register_background_index_update(
             "user_ips_device_unique_index",
             index_name="user_ips_user_token_ip_unique_index",
             table="user_ips",
@@ -75,12 +76,12 @@ class ClientIpBackgroundUpdateStore(background_updates.BackgroundUpdateStore):
         )
 
         # Drop the old non-unique index
-        self.register_background_update_handler(
+        self.db.updates.register_background_update_handler(
             "user_ips_drop_nonunique_index", self._remove_user_ip_nonunique
         )
 
         # Update the last seen info in devices.
-        self.register_background_update_handler(
+        self.db.updates.register_background_update_handler(
             "devices_last_seen", self._devices_last_seen_update
         )
 
@@ -91,8 +92,8 @@ class ClientIpBackgroundUpdateStore(background_updates.BackgroundUpdateStore):
             txn.execute("DROP INDEX IF EXISTS user_ips_user_ip")
             txn.close()
 
-        yield self.runWithConnection(f)
-        yield self._end_background_update("user_ips_drop_nonunique_index")
+        yield self.db.runWithConnection(f)
+        yield self.db.updates._end_background_update("user_ips_drop_nonunique_index")
         return 1
 
     @defer.inlineCallbacks
@@ -106,9 +107,9 @@ class ClientIpBackgroundUpdateStore(background_updates.BackgroundUpdateStore):
         def user_ips_analyze(txn):
             txn.execute("ANALYZE user_ips")
 
-        yield self.runInteraction("user_ips_analyze", user_ips_analyze)
+        yield self.db.runInteraction("user_ips_analyze", user_ips_analyze)
 
-        yield self._end_background_update("user_ips_analyze")
+        yield self.db.updates._end_background_update("user_ips_analyze")
 
         return 1
 
@@ -140,7 +141,7 @@ class ClientIpBackgroundUpdateStore(background_updates.BackgroundUpdateStore):
                 return None
 
         # Get a last seen that has roughly `batch_size` since `begin_last_seen`
-        end_last_seen = yield self.runInteraction(
+        end_last_seen = yield self.db.runInteraction(
             "user_ips_dups_get_last_seen", get_last_seen
         )
 
@@ -271,14 +272,14 @@ class ClientIpBackgroundUpdateStore(background_updates.BackgroundUpdateStore):
                     (user_id, access_token, ip, device_id, user_agent, last_seen),
                 )
 
-            self._background_update_progress_txn(
+            self.db.updates._background_update_progress_txn(
                 txn, "user_ips_remove_dupes", {"last_seen": end_last_seen}
             )
 
-        yield self.runInteraction("user_ips_dups_remove", remove)
+        yield self.db.runInteraction("user_ips_dups_remove", remove)
 
         if last:
-            yield self._end_background_update("user_ips_remove_dupes")
+            yield self.db.updates._end_background_update("user_ips_remove_dupes")
 
         return batch_size
 
@@ -344,7 +345,7 @@ class ClientIpBackgroundUpdateStore(background_updates.BackgroundUpdateStore):
             txn.execute_batch(sql, rows)
 
             _, _, _, user_id, device_id = rows[-1]
-            self._background_update_progress_txn(
+            self.db.updates._background_update_progress_txn(
                 txn,
                 "devices_last_seen",
                 {"last_user_id": user_id, "last_device_id": device_id},
@@ -352,24 +353,24 @@ class ClientIpBackgroundUpdateStore(background_updates.BackgroundUpdateStore):
 
             return len(rows)
 
-        updated = yield self.runInteraction(
+        updated = yield self.db.runInteraction(
             "_devices_last_seen_update", _devices_last_seen_update_txn
         )
 
         if not updated:
-            yield self._end_background_update("devices_last_seen")
+            yield self.db.updates._end_background_update("devices_last_seen")
 
         return updated
 
 
 class ClientIpStore(ClientIpBackgroundUpdateStore):
-    def __init__(self, db_conn, hs):
+    def __init__(self, database: Database, db_conn, hs):
 
         self.client_ip_last_seen = Cache(
             name="client_ip_last_seen", keylen=4, max_entries=50000 * CACHE_SIZE_FACTOR
         )
 
-        super(ClientIpStore, self).__init__(db_conn, hs)
+        super(ClientIpStore, self).__init__(database, db_conn, hs)
 
         self.user_ips_max_age = hs.config.user_ips_max_age
 
@@ -417,12 +418,12 @@ class ClientIpStore(ClientIpBackgroundUpdateStore):
         to_update = self._batch_row_update
         self._batch_row_update = {}
 
-        return self.runInteraction(
+        return self.db.runInteraction(
             "_update_client_ips_batch", self._update_client_ips_batch_txn, to_update
         )
 
     def _update_client_ips_batch_txn(self, txn, to_update):
-        if "user_ips" in self._unsafe_to_upsert_tables or (
+        if "user_ips" in self.db._unsafe_to_upsert_tables or (
             not self.database_engine.can_native_upsert
         ):
             self.database_engine.lock_table(txn, "user_ips")
@@ -431,7 +432,7 @@ class ClientIpStore(ClientIpBackgroundUpdateStore):
             (user_id, access_token, ip), (user_agent, device_id, last_seen) = entry
 
             try:
-                self.simple_upsert_txn(
+                self.db.simple_upsert_txn(
                     txn,
                     table="user_ips",
                     keyvalues={
@@ -450,16 +451,18 @@ class ClientIpStore(ClientIpBackgroundUpdateStore):
                 # Technically an access token might not be associated with
                 # a device so we need to check.
                 if device_id:
-                    self.simple_upsert_txn(
+                    # this is always an update rather than an upsert: the row should
+                    # already exist, and if it doesn't, that may be because it has been
+                    # deleted, and we don't want to re-create it.
+                    self.db.simple_update_txn(
                         txn,
                         table="devices",
                         keyvalues={"user_id": user_id, "device_id": device_id},
-                        values={
+                        updatevalues={
                             "user_agent": user_agent,
                             "last_seen": last_seen,
                             "ip": ip,
                         },
-                        lock=False,
                     )
             except Exception as e:
                 # Failed to upsert, log and continue
@@ -483,7 +486,7 @@ class ClientIpStore(ClientIpBackgroundUpdateStore):
         if device_id is not None:
             keyvalues["device_id"] = device_id
 
-        res = yield self.simple_select_list(
+        res = yield self.db.simple_select_list(
             table="devices",
             keyvalues=keyvalues,
             retcols=("user_id", "ip", "user_agent", "device_id", "last_seen"),
@@ -516,7 +519,7 @@ class ClientIpStore(ClientIpBackgroundUpdateStore):
                 user_agent, _, last_seen = self._batch_row_update[key]
                 results[(access_token, ip)] = (user_agent, last_seen)
 
-        rows = yield self.simple_select_list(
+        rows = yield self.db.simple_select_list(
             table="user_ips",
             keyvalues={"user_id": user_id},
             retcols=["access_token", "ip", "user_agent", "last_seen"],
@@ -546,7 +549,9 @@ class ClientIpStore(ClientIpBackgroundUpdateStore):
             # Nothing to do
             return
 
-        if not await self.has_completed_background_update("devices_last_seen"):
+        if not await self.db.updates.has_completed_background_update(
+            "devices_last_seen"
+        ):
             # Only start pruning if we have finished populating the devices
             # last seen info.
             return
@@ -577,4 +582,4 @@ class ClientIpStore(ClientIpBackgroundUpdateStore):
         def _prune_old_user_ips_txn(txn):
             txn.execute(sql, (timestamp,))
 
-        await self.runInteraction("_prune_old_user_ips", _prune_old_user_ips_txn)
+        await self.db.runInteraction("_prune_old_user_ips", _prune_old_user_ips_txn)
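
Client IP writes are coalesced: entries accumulate in `_batch_row_update` and are flushed in `_update_client_ips_batch_txn`, and `LAST_SEEN_GRANULARITY` (two minutes, per the constant above) means a given `(user, token, ip)` row is only rewritten once its cached `last_seen` is stale. A sketch of that staleness test, with the function name assumed:

    LAST_SEEN_GRANULARITY = 120 * 1000  # ms, as defined in this file

    def needs_write(cached_last_seen, now_ms):
        # skip the DB write if this (user, token, ip) was refreshed recently
        return cached_last_seen is None or (
            now_ms - cached_last_seen > LAST_SEEN_GRANULARITY
        )

Note also the deliberate switch from `simple_upsert_txn` to `simple_update_txn` for the `devices` table: an upsert would re-create a row for a deleted device, whereas an update is a no-op if the row is gone.
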
diff --git a/synapse/storage/data_stores/main/deviceinbox.py b/synapse/storage/data_stores/main/deviceinbox.py
index 206d39134d..85cfa16850 100644
--- a/synapse/storage/data_stores/main/deviceinbox.py
+++ b/synapse/storage/data_stores/main/deviceinbox.py
@@ -21,7 +21,7 @@ from twisted.internet import defer
 
 from synapse.logging.opentracing import log_kv, set_tag, trace
 from synapse.storage._base import SQLBaseStore, make_in_list_sql_clause
-from synapse.storage.background_updates import BackgroundUpdateStore
+from synapse.storage.database import Database
 from synapse.util.caches.expiringcache import ExpiringCache
 
 logger = logging.getLogger(__name__)
@@ -69,7 +69,7 @@ class DeviceInboxWorkerStore(SQLBaseStore):
                 stream_pos = current_stream_id
             return messages, stream_pos
 
-        return self.runInteraction(
+        return self.db.runInteraction(
             "get_new_messages_for_device", get_new_messages_for_device_txn
         )
 
@@ -109,7 +109,7 @@ class DeviceInboxWorkerStore(SQLBaseStore):
             txn.execute(sql, (user_id, device_id, up_to_stream_id))
             return txn.rowcount
 
-        count = yield self.runInteraction(
+        count = yield self.db.runInteraction(
             "delete_messages_for_device", delete_messages_for_device_txn
         )
 
@@ -178,7 +178,7 @@ class DeviceInboxWorkerStore(SQLBaseStore):
                 stream_pos = current_stream_id
             return messages, stream_pos
 
-        return self.runInteraction(
+        return self.db.runInteraction(
             "get_new_device_msgs_for_remote",
             get_new_messages_for_remote_destination_txn,
         )
@@ -203,25 +203,25 @@ class DeviceInboxWorkerStore(SQLBaseStore):
             )
             txn.execute(sql, (destination, up_to_stream_id))
 
-        return self.runInteraction(
+        return self.db.runInteraction(
             "delete_device_msgs_for_remote", delete_messages_for_remote_destination_txn
         )
 
 
-class DeviceInboxBackgroundUpdateStore(BackgroundUpdateStore):
+class DeviceInboxBackgroundUpdateStore(SQLBaseStore):
     DEVICE_INBOX_STREAM_ID = "device_inbox_stream_drop"
 
-    def __init__(self, db_conn, hs):
-        super(DeviceInboxBackgroundUpdateStore, self).__init__(db_conn, hs)
+    def __init__(self, database: Database, db_conn, hs):
+        super(DeviceInboxBackgroundUpdateStore, self).__init__(database, db_conn, hs)
 
-        self.register_background_index_update(
+        self.db.updates.register_background_index_update(
             "device_inbox_stream_index",
             index_name="device_inbox_stream_id_user_id",
             table="device_inbox",
             columns=["stream_id", "user_id"],
         )
 
-        self.register_background_update_handler(
+        self.db.updates.register_background_update_handler(
             self.DEVICE_INBOX_STREAM_ID, self._background_drop_index_device_inbox
         )
 
@@ -232,9 +232,9 @@ class DeviceInboxBackgroundUpdateStore(BackgroundUpdateStore):
             txn.execute("DROP INDEX IF EXISTS device_inbox_stream_id")
             txn.close()
 
-        yield self.runWithConnection(reindex_txn)
+        yield self.db.runWithConnection(reindex_txn)
 
-        yield self._end_background_update(self.DEVICE_INBOX_STREAM_ID)
+        yield self.db.updates._end_background_update(self.DEVICE_INBOX_STREAM_ID)
 
         return 1
 
@@ -242,8 +242,8 @@ class DeviceInboxBackgroundUpdateStore(BackgroundUpdateStore):
 class DeviceInboxStore(DeviceInboxWorkerStore, DeviceInboxBackgroundUpdateStore):
     DEVICE_INBOX_STREAM_ID = "device_inbox_stream_drop"
 
-    def __init__(self, db_conn, hs):
-        super(DeviceInboxStore, self).__init__(db_conn, hs)
+    def __init__(self, database: Database, db_conn, hs):
+        super(DeviceInboxStore, self).__init__(database, db_conn, hs)
 
         # Map of (user_id, device_id) to the last stream_id that has been
         # deleted up to. This is so that we can no op deletions.
@@ -294,7 +294,7 @@ class DeviceInboxStore(DeviceInboxWorkerStore, DeviceInboxBackgroundUpdateStore)
 
         with self._device_inbox_id_gen.get_next() as stream_id:
             now_ms = self.clock.time_msec()
-            yield self.runInteraction(
+            yield self.db.runInteraction(
                 "add_messages_to_device_inbox", add_messages_txn, now_ms, stream_id
             )
             for user_id in local_messages_by_user_then_device.keys():
@@ -314,7 +314,7 @@ class DeviceInboxStore(DeviceInboxWorkerStore, DeviceInboxBackgroundUpdateStore)
             # Check if we've already inserted a matching message_id for that
             # origin. This can happen if the origin doesn't receive our
             # acknowledgement from the first time we received the message.
-            already_inserted = self.simple_select_one_txn(
+            already_inserted = self.db.simple_select_one_txn(
                 txn,
                 table="device_federation_inbox",
                 keyvalues={"origin": origin, "message_id": message_id},
@@ -326,7 +326,7 @@ class DeviceInboxStore(DeviceInboxWorkerStore, DeviceInboxBackgroundUpdateStore)
 
             # Add an entry for this message_id so that we know we've processed
             # it.
-            self.simple_insert_txn(
+            self.db.simple_insert_txn(
                 txn,
                 table="device_federation_inbox",
                 values={
@@ -344,7 +344,7 @@ class DeviceInboxStore(DeviceInboxWorkerStore, DeviceInboxBackgroundUpdateStore)
 
         with self._device_inbox_id_gen.get_next() as stream_id:
             now_ms = self.clock.time_msec()
-            yield self.runInteraction(
+            yield self.db.runInteraction(
                 "add_messages_from_remote_to_device_inbox",
                 add_messages_txn,
                 now_ms,
@@ -465,6 +465,6 @@ class DeviceInboxStore(DeviceInboxWorkerStore, DeviceInboxBackgroundUpdateStore)
 
             return rows
 
-        return self.runInteraction(
+        return self.db.runInteraction(
             "get_all_new_device_messages", get_all_new_device_messages_txn
         )
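
`DeviceInboxStore` remembers, per `(user_id, device_id)`, the highest `stream_id` it has already deleted up to, so repeated delete calls from a client can be answered without touching the database. A sketch of that short-circuit, with the cache name assumed:

    # `last_delete_cache` stands in for the store's ExpiringCache
    last_deleted = last_delete_cache.get((user_id, device_id))
    if last_deleted is not None and up_to_stream_id <= last_deleted:
        return 0  # already deleted up to this stream_id: no-op
    # otherwise run the DELETE and record up_to_stream_id in the cache
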
diff --git a/synapse/storage/data_stores/main/devices.py b/synapse/storage/data_stores/main/devices.py
index 727c582121..9a828231c4 100644
--- a/synapse/storage/data_stores/main/devices.py
+++ b/synapse/storage/data_stores/main/devices.py
@@ -31,7 +31,7 @@ from synapse.logging.opentracing import (
 )
 from synapse.metrics.background_process_metrics import run_as_background_process
 from synapse.storage._base import SQLBaseStore, db_to_json, make_in_list_sql_clause
-from synapse.storage.background_updates import BackgroundUpdateStore
+from synapse.storage.database import Database
 from synapse.types import get_verify_key_from_cross_signing_key
 from synapse.util import batch_iter
 from synapse.util.caches.descriptors import (
@@ -61,7 +61,7 @@ class DeviceWorkerStore(SQLBaseStore):
         Raises:
             StoreError: if the device is not found
         """
-        return self.simple_select_one(
+        return self.db.simple_select_one(
             table="devices",
             keyvalues={"user_id": user_id, "device_id": device_id, "hidden": False},
             retcols=("user_id", "device_id", "display_name"),
@@ -80,7 +80,7 @@ class DeviceWorkerStore(SQLBaseStore):
             containing "device_id", "user_id" and "display_name" for each
             device.
         """
-        devices = yield self.simple_select_list(
+        devices = yield self.db.simple_select_list(
             table="devices",
             keyvalues={"user_id": user_id, "hidden": False},
             retcols=("user_id", "device_id", "display_name"),
@@ -122,7 +122,7 @@ class DeviceWorkerStore(SQLBaseStore):
         # consider the device update to be too large, and simply skip the
         # stream_id; the rationale being that such a large device list update
         # is likely an error.
-        updates = yield self.runInteraction(
+        updates = yield self.db.runInteraction(
             "get_device_updates_by_remote",
             self._get_device_updates_by_remote_txn,
             destination,
@@ -283,7 +283,7 @@ class DeviceWorkerStore(SQLBaseStore):
 
         """
         devices = (
-            yield self.runInteraction(
+            yield self.db.runInteraction(
                 "_get_e2e_device_keys_txn",
                 self._get_e2e_device_keys_txn,
                 query_map.keys(),
@@ -340,12 +340,12 @@ class DeviceWorkerStore(SQLBaseStore):
             rows = txn.fetchall()
             return rows[0][0]
 
-        return self.runInteraction("get_last_device_update_for_remote_user", f)
+        return self.db.runInteraction("get_last_device_update_for_remote_user", f)
 
     def mark_as_sent_devices_by_remote(self, destination, stream_id):
         """Mark that updates have successfully been sent to the destination.
         """
-        return self.runInteraction(
+        return self.db.runInteraction(
             "mark_as_sent_devices_by_remote",
             self._mark_as_sent_devices_by_remote_txn,
             destination,
@@ -399,7 +399,7 @@ class DeviceWorkerStore(SQLBaseStore):
         """
 
         with self._device_list_id_gen.get_next() as stream_id:
-            yield self.runInteraction(
+            yield self.db.runInteraction(
                 "add_user_sig_change_to_streams",
                 self._add_user_signature_change_txn,
                 from_user_id,
@@ -414,7 +414,7 @@ class DeviceWorkerStore(SQLBaseStore):
             from_user_id,
             stream_id,
         )
-        self.simple_insert_txn(
+        self.db.simple_insert_txn(
             txn,
             "user_signature_stream",
             values={
@@ -466,7 +466,7 @@ class DeviceWorkerStore(SQLBaseStore):
 
     @cachedInlineCallbacks(num_args=2, tree=True)
     def _get_cached_user_device(self, user_id, device_id):
-        content = yield self.simple_select_one_onecol(
+        content = yield self.db.simple_select_one_onecol(
             table="device_lists_remote_cache",
             keyvalues={"user_id": user_id, "device_id": device_id},
             retcol="content",
@@ -476,7 +476,7 @@ class DeviceWorkerStore(SQLBaseStore):
 
     @cachedInlineCallbacks()
     def _get_cached_devices_for_user(self, user_id):
-        devices = yield self.simple_select_list(
+        devices = yield self.db.simple_select_list(
             table="device_lists_remote_cache",
             keyvalues={"user_id": user_id},
             retcols=("device_id", "content"),
@@ -492,7 +492,7 @@ class DeviceWorkerStore(SQLBaseStore):
         Returns:
             (stream_id, devices)
         """
-        return self.runInteraction(
+        return self.db.runInteraction(
             "get_devices_with_keys_by_user",
             self._get_devices_with_keys_by_user_txn,
             user_id,
@@ -565,7 +565,7 @@ class DeviceWorkerStore(SQLBaseStore):
 
             return changes
 
-        return self.runInteraction(
+        return self.db.runInteraction(
             "get_users_whose_devices_changed", _get_users_whose_devices_changed_txn
         )
 
@@ -584,7 +584,7 @@ class DeviceWorkerStore(SQLBaseStore):
                 SELECT DISTINCT user_ids FROM user_signature_stream
                 WHERE from_user_id = ? AND stream_id > ?
             """
-            rows = yield self.execute(
+            rows = yield self.db.execute(
                 "get_users_whose_signatures_changed", None, sql, user_id, from_key
             )
             return set(user for row in rows for user in json.loads(row[0]))
@@ -605,7 +605,7 @@ class DeviceWorkerStore(SQLBaseStore):
             WHERE ? < stream_id AND stream_id <= ?
             GROUP BY user_id, destination
         """
-        return self.execute(
+        return self.db.execute(
             "get_all_device_list_changes_for_remotes", None, sql, from_key, to_key
         )
 
@@ -614,7 +614,7 @@ class DeviceWorkerStore(SQLBaseStore):
         """Get the last stream_id we got for a user. May be None if we haven't
         got any information for them.
         """
-        return self.simple_select_one_onecol(
+        return self.db.simple_select_one_onecol(
             table="device_lists_remote_extremeties",
             keyvalues={"user_id": user_id},
             retcol="stream_id",
@@ -628,7 +628,7 @@ class DeviceWorkerStore(SQLBaseStore):
         inlineCallbacks=True,
     )
     def get_device_list_last_stream_id_for_remotes(self, user_ids):
-        rows = yield self.simple_select_many_batch(
+        rows = yield self.db.simple_select_many_batch(
             table="device_lists_remote_extremeties",
             column="user_id",
             iterable=user_ids,
@@ -642,11 +642,11 @@ class DeviceWorkerStore(SQLBaseStore):
         return results
 
 
-class DeviceBackgroundUpdateStore(BackgroundUpdateStore):
-    def __init__(self, db_conn, hs):
-        super(DeviceBackgroundUpdateStore, self).__init__(db_conn, hs)
+class DeviceBackgroundUpdateStore(SQLBaseStore):
+    def __init__(self, database: Database, db_conn, hs):
+        super(DeviceBackgroundUpdateStore, self).__init__(database, db_conn, hs)
 
-        self.register_background_index_update(
+        self.db.updates.register_background_index_update(
             "device_lists_stream_idx",
             index_name="device_lists_stream_user_id",
             table="device_lists_stream",
@@ -654,7 +654,7 @@ class DeviceBackgroundUpdateStore(BackgroundUpdateStore):
         )
 
         # create a unique index on device_lists_remote_cache
-        self.register_background_index_update(
+        self.db.updates.register_background_index_update(
             "device_lists_remote_cache_unique_idx",
             index_name="device_lists_remote_cache_unique_id",
             table="device_lists_remote_cache",
@@ -663,7 +663,7 @@ class DeviceBackgroundUpdateStore(BackgroundUpdateStore):
         )
 
         # And one on device_lists_remote_extremeties
-        self.register_background_index_update(
+        self.db.updates.register_background_index_update(
             "device_lists_remote_extremeties_unique_idx",
             index_name="device_lists_remote_extremeties_unique_idx",
             table="device_lists_remote_extremeties",
@@ -672,7 +672,7 @@ class DeviceBackgroundUpdateStore(BackgroundUpdateStore):
         )
 
         # once they complete, we can remove the old non-unique indexes.
-        self.register_background_update_handler(
+        self.db.updates.register_background_update_handler(
             DROP_DEVICE_LIST_STREAMS_NON_UNIQUE_INDEXES,
             self._drop_device_list_streams_non_unique_indexes,
         )
@@ -685,14 +685,16 @@ class DeviceBackgroundUpdateStore(BackgroundUpdateStore):
             txn.execute("DROP INDEX IF EXISTS device_lists_remote_extremeties_id")
             txn.close()
 
-        yield self.runWithConnection(f)
-        yield self._end_background_update(DROP_DEVICE_LIST_STREAMS_NON_UNIQUE_INDEXES)
+        yield self.db.runWithConnection(f)
+        yield self.db.updates._end_background_update(
+            DROP_DEVICE_LIST_STREAMS_NON_UNIQUE_INDEXES
+        )
         return 1
 
 
 class DeviceStore(DeviceWorkerStore, DeviceBackgroundUpdateStore):
-    def __init__(self, db_conn, hs):
-        super(DeviceStore, self).__init__(db_conn, hs)
+    def __init__(self, database: Database, db_conn, hs):
+        super(DeviceStore, self).__init__(database, db_conn, hs)
 
         # Map of (user_id, device_id) -> bool. If there is an entry that implies
         # the device exists.
@@ -722,7 +724,7 @@ class DeviceStore(DeviceWorkerStore, DeviceBackgroundUpdateStore):
             return False
 
         try:
-            inserted = yield self.simple_insert(
+            inserted = yield self.db.simple_insert(
                 "devices",
                 values={
                     "user_id": user_id,
@@ -736,7 +738,7 @@ class DeviceStore(DeviceWorkerStore, DeviceBackgroundUpdateStore):
             if not inserted:
                 # if the device already exists, check if it's a real device, or
                 # if the device ID is reserved by something else
-                hidden = yield self.simple_select_one_onecol(
+                hidden = yield self.db.simple_select_one_onecol(
                     "devices",
                     keyvalues={"user_id": user_id, "device_id": device_id},
                     retcol="hidden",
@@ -771,7 +773,7 @@ class DeviceStore(DeviceWorkerStore, DeviceBackgroundUpdateStore):
         Returns:
             defer.Deferred
         """
-        yield self.simple_delete_one(
+        yield self.db.simple_delete_one(
             table="devices",
             keyvalues={"user_id": user_id, "device_id": device_id, "hidden": False},
             desc="delete_device",
@@ -789,7 +791,7 @@ class DeviceStore(DeviceWorkerStore, DeviceBackgroundUpdateStore):
         Returns:
             defer.Deferred
         """
-        yield self.simple_delete_many(
+        yield self.db.simple_delete_many(
             table="devices",
             column="device_id",
             iterable=device_ids,
@@ -818,7 +820,7 @@ class DeviceStore(DeviceWorkerStore, DeviceBackgroundUpdateStore):
             updates["display_name"] = new_display_name
         if not updates:
             return defer.succeed(None)
-        return self.simple_update_one(
+        return self.db.simple_update_one(
             table="devices",
             keyvalues={"user_id": user_id, "device_id": device_id, "hidden": False},
             updatevalues=updates,
@@ -829,7 +831,7 @@ class DeviceStore(DeviceWorkerStore, DeviceBackgroundUpdateStore):
     def mark_remote_user_device_list_as_unsubscribed(self, user_id):
         """Mark that we no longer track device lists for remote user.
         """
-        yield self.simple_delete(
+        yield self.db.simple_delete(
             table="device_lists_remote_extremeties",
             keyvalues={"user_id": user_id},
             desc="mark_remote_user_device_list_as_unsubscribed",
@@ -853,7 +855,7 @@ class DeviceStore(DeviceWorkerStore, DeviceBackgroundUpdateStore):
         Returns:
             Deferred[None]
         """
-        return self.runInteraction(
+        return self.db.runInteraction(
             "update_remote_device_list_cache_entry",
             self._update_remote_device_list_cache_entry_txn,
             user_id,
@@ -866,7 +868,7 @@ class DeviceStore(DeviceWorkerStore, DeviceBackgroundUpdateStore):
         self, txn, user_id, device_id, content, stream_id
     ):
         if content.get("deleted"):
-            self.simple_delete_txn(
+            self.db.simple_delete_txn(
                 txn,
                 table="device_lists_remote_cache",
                 keyvalues={"user_id": user_id, "device_id": device_id},
@@ -874,7 +876,7 @@ class DeviceStore(DeviceWorkerStore, DeviceBackgroundUpdateStore):
 
             txn.call_after(self.device_id_exists_cache.invalidate, (user_id, device_id))
         else:
-            self.simple_upsert_txn(
+            self.db.simple_upsert_txn(
                 txn,
                 table="device_lists_remote_cache",
                 keyvalues={"user_id": user_id, "device_id": device_id},
@@ -890,7 +892,7 @@ class DeviceStore(DeviceWorkerStore, DeviceBackgroundUpdateStore):
             self.get_device_list_last_stream_id_for_remote.invalidate, (user_id,)
         )
 
-        self.simple_upsert_txn(
+        self.db.simple_upsert_txn(
             txn,
             table="device_lists_remote_extremeties",
             keyvalues={"user_id": user_id},
@@ -914,7 +916,7 @@ class DeviceStore(DeviceWorkerStore, DeviceBackgroundUpdateStore):
         Returns:
             Deferred[None]
         """
-        return self.runInteraction(
+        return self.db.runInteraction(
             "update_remote_device_list_cache",
             self._update_remote_device_list_cache_txn,
             user_id,
@@ -923,11 +925,11 @@ class DeviceStore(DeviceWorkerStore, DeviceBackgroundUpdateStore):
         )
 
     def _update_remote_device_list_cache_txn(self, txn, user_id, devices, stream_id):
-        self.simple_delete_txn(
+        self.db.simple_delete_txn(
             txn, table="device_lists_remote_cache", keyvalues={"user_id": user_id}
         )
 
-        self.simple_insert_many_txn(
+        self.db.simple_insert_many_txn(
             txn,
             table="device_lists_remote_cache",
             values=[
@@ -946,7 +948,7 @@ class DeviceStore(DeviceWorkerStore, DeviceBackgroundUpdateStore):
             self.get_device_list_last_stream_id_for_remote.invalidate, (user_id,)
         )
 
-        self.simple_upsert_txn(
+        self.db.simple_upsert_txn(
             txn,
             table="device_lists_remote_extremeties",
             keyvalues={"user_id": user_id},
@@ -962,7 +964,7 @@ class DeviceStore(DeviceWorkerStore, DeviceBackgroundUpdateStore):
         (if any) should be poked.
         """
         with self._device_list_id_gen.get_next() as stream_id:
-            yield self.runInteraction(
+            yield self.db.runInteraction(
                 "add_device_change_to_streams",
                 self._add_device_change_txn,
                 user_id,
@@ -995,7 +997,7 @@ class DeviceStore(DeviceWorkerStore, DeviceBackgroundUpdateStore):
             [(user_id, device_id, stream_id) for device_id in device_ids],
         )
 
-        self.simple_insert_many_txn(
+        self.db.simple_insert_many_txn(
             txn,
             table="device_lists_stream",
             values=[
@@ -1006,7 +1008,7 @@ class DeviceStore(DeviceWorkerStore, DeviceBackgroundUpdateStore):
 
         context = get_active_span_text_map()
 
-        self.simple_insert_many_txn(
+        self.db.simple_insert_many_txn(
             txn,
             table="device_lists_outbound_pokes",
             values=[
@@ -1069,7 +1071,7 @@ class DeviceStore(DeviceWorkerStore, DeviceBackgroundUpdateStore):
 
         return run_as_background_process(
             "prune_old_outbound_device_pokes",
-            self.runInteraction,
+            self.db.runInteraction,
             "_prune_old_outbound_device_pokes",
             _prune_txn,
         )
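
All of the devices.py hunks above are instances of one mechanical change: the simple_* helpers and runInteraction now live on the Database object exposed as self.db, rather than being inherited from the store itself. A minimal sketch of the new calling convention, using a hypothetical store and table (not part of this commit):

    from twisted.internet import defer

    from synapse.storage._base import SQLBaseStore


    class ExampleDeviceStore(SQLBaseStore):
        @defer.inlineCallbacks
        def get_device_colour(self, user_id, device_id):
            # Before this commit: yield self.simple_select_one_onecol(...).
            # The same helper is now reached through the Database object.
            colour = yield self.db.simple_select_one_onecol(
                table="device_colours",  # hypothetical table
                keyvalues={"user_id": user_id, "device_id": device_id},
                retcol="colour",
                desc="get_device_colour",  # label used for metrics/tracing
            )
            return colour
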
diff --git a/synapse/storage/data_stores/main/directory.py b/synapse/storage/data_stores/main/directory.py
index d332f8a409..c9e7de7d12 100644
--- a/synapse/storage/data_stores/main/directory.py
+++ b/synapse/storage/data_stores/main/directory.py
@@ -36,7 +36,7 @@ class DirectoryWorkerStore(SQLBaseStore):
             Deferred: results in a namedtuple with keys "room_id" and
             "servers", or None if no association can be found
         """
-        room_id = yield self.simple_select_one_onecol(
+        room_id = yield self.db.simple_select_one_onecol(
             "room_aliases",
             {"room_alias": room_alias.to_string()},
             "room_id",
@@ -47,7 +47,7 @@ class DirectoryWorkerStore(SQLBaseStore):
         if not room_id:
             return None
 
-        servers = yield self.simple_select_onecol(
+        servers = yield self.db.simple_select_onecol(
             "room_alias_servers",
             {"room_alias": room_alias.to_string()},
             "server",
@@ -60,7 +60,7 @@ class DirectoryWorkerStore(SQLBaseStore):
         return RoomAliasMapping(room_id, room_alias.to_string(), servers)
 
     def get_room_alias_creator(self, room_alias):
-        return self.simple_select_one_onecol(
+        return self.db.simple_select_one_onecol(
             table="room_aliases",
             keyvalues={"room_alias": room_alias},
             retcol="creator",
@@ -69,7 +69,7 @@ class DirectoryWorkerStore(SQLBaseStore):
 
     @cached(max_entries=5000)
     def get_aliases_for_room(self, room_id):
-        return self.simple_select_onecol(
+        return self.db.simple_select_onecol(
             "room_aliases",
             {"room_id": room_id},
             "room_alias",
@@ -93,7 +93,7 @@ class DirectoryStore(DirectoryWorkerStore):
         """
 
         def alias_txn(txn):
-            self.simple_insert_txn(
+            self.db.simple_insert_txn(
                 txn,
                 "room_aliases",
                 {
@@ -103,7 +103,7 @@ class DirectoryStore(DirectoryWorkerStore):
                 },
             )
 
-            self.simple_insert_many_txn(
+            self.db.simple_insert_many_txn(
                 txn,
                 table="room_alias_servers",
                 values=[
@@ -117,7 +117,9 @@ class DirectoryStore(DirectoryWorkerStore):
             )
 
         try:
-            ret = yield self.runInteraction("create_room_alias_association", alias_txn)
+            ret = yield self.db.runInteraction(
+                "create_room_alias_association", alias_txn
+            )
         except self.database_engine.module.IntegrityError:
             raise SynapseError(
                 409, "Room alias %s already exists" % room_alias.to_string()
@@ -126,7 +128,7 @@ class DirectoryStore(DirectoryWorkerStore):
 
     @defer.inlineCallbacks
     def delete_room_alias(self, room_alias):
-        room_id = yield self.runInteraction(
+        room_id = yield self.db.runInteraction(
             "delete_room_alias", self._delete_room_alias_txn, room_alias
         )
 
@@ -168,6 +170,6 @@ class DirectoryStore(DirectoryWorkerStore):
                 txn, self.get_aliases_for_room, (new_room_id,)
             )
 
-        return self.runInteraction(
+        return self.db.runInteraction(
             "_update_aliases_for_room_txn", _update_aliases_for_room_txn
         )
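
A pattern worth noting in the directory.py hunks: where several statements must succeed or fail together, a local *_txn function is handed to self.db.runInteraction, and the database engine's IntegrityError is translated into an API-level error. A hedged sketch of that shape (hypothetical method; the table follows the hunks above):

    from twisted.internet import defer

    from synapse.api.errors import SynapseError
    from synapse.storage._base import SQLBaseStore


    class ExampleDirectoryStore(SQLBaseStore):
        @defer.inlineCallbacks
        def create_alias(self, room_alias, room_id, creator):
            # The insert runs inside a single transaction; a duplicate-key
            # failure surfaces as HTTP 409 rather than a raw database error.
            def alias_txn(txn):
                self.db.simple_insert_txn(
                    txn,
                    "room_aliases",
                    {"room_alias": room_alias, "room_id": room_id, "creator": creator},
                )

            try:
                yield self.db.runInteraction("create_alias", alias_txn)
            except self.database_engine.module.IntegrityError:
                raise SynapseError(409, "Room alias %s already exists" % room_alias)
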
diff --git a/synapse/storage/data_stores/main/e2e_room_keys.py b/synapse/storage/data_stores/main/e2e_room_keys.py
index df89eda337..84594cf0a9 100644
--- a/synapse/storage/data_stores/main/e2e_room_keys.py
+++ b/synapse/storage/data_stores/main/e2e_room_keys.py
@@ -38,7 +38,7 @@ class EndToEndRoomKeyStore(SQLBaseStore):
             StoreError
         """
 
-        yield self.simple_update_one(
+        yield self.db.simple_update_one(
             table="e2e_room_keys",
             keyvalues={
                 "user_id": user_id,
@@ -89,7 +89,7 @@ class EndToEndRoomKeyStore(SQLBaseStore):
                 }
             )
 
-        yield self.simple_insert_many(
+        yield self.db.simple_insert_many(
             table="e2e_room_keys", values=values, desc="add_e2e_room_keys"
         )
 
@@ -125,7 +125,7 @@ class EndToEndRoomKeyStore(SQLBaseStore):
             if session_id:
                 keyvalues["session_id"] = session_id
 
-        rows = yield self.simple_select_list(
+        rows = yield self.db.simple_select_list(
             table="e2e_room_keys",
             keyvalues=keyvalues,
             retcols=(
@@ -170,7 +170,7 @@ class EndToEndRoomKeyStore(SQLBaseStore):
            Deferred[dict[str, dict[str, dict]]]: a map of room IDs to session IDs to room key
         """
 
-        return self.runInteraction(
+        return self.db.runInteraction(
             "get_e2e_room_keys_multi",
             self._get_e2e_room_keys_multi_txn,
             user_id,
@@ -234,7 +234,7 @@ class EndToEndRoomKeyStore(SQLBaseStore):
             version (str): the version ID of the backup we're querying about
         """
 
-        return self.simple_select_one_onecol(
+        return self.db.simple_select_one_onecol(
             table="e2e_room_keys",
             keyvalues={"user_id": user_id, "version": version},
             retcol="COUNT(*)",
@@ -267,7 +267,7 @@ class EndToEndRoomKeyStore(SQLBaseStore):
             if session_id:
                 keyvalues["session_id"] = session_id
 
-        yield self.simple_delete(
+        yield self.db.simple_delete(
             table="e2e_room_keys", keyvalues=keyvalues, desc="delete_e2e_room_keys"
         )
 
@@ -312,7 +312,7 @@ class EndToEndRoomKeyStore(SQLBaseStore):
                     # it isn't there.
                     raise StoreError(404, "No row found")
 
-            result = self.simple_select_one_txn(
+            result = self.db.simple_select_one_txn(
                 txn,
                 table="e2e_room_keys_versions",
                 keyvalues={"user_id": user_id, "version": this_version, "deleted": 0},
@@ -324,7 +324,7 @@ class EndToEndRoomKeyStore(SQLBaseStore):
                 result["etag"] = 0
             return result
 
-        return self.runInteraction(
+        return self.db.runInteraction(
             "get_e2e_room_keys_version_info", _get_e2e_room_keys_version_info_txn
         )
 
@@ -352,7 +352,7 @@ class EndToEndRoomKeyStore(SQLBaseStore):
 
             new_version = str(int(current_version) + 1)
 
-            self.simple_insert_txn(
+            self.db.simple_insert_txn(
                 txn,
                 table="e2e_room_keys_versions",
                 values={
@@ -365,7 +365,7 @@ class EndToEndRoomKeyStore(SQLBaseStore):
 
             return new_version
 
-        return self.runInteraction(
+        return self.db.runInteraction(
             "create_e2e_room_keys_version_txn", _create_e2e_room_keys_version_txn
         )
 
@@ -391,7 +391,7 @@ class EndToEndRoomKeyStore(SQLBaseStore):
             updatevalues["etag"] = version_etag
 
         if updatevalues:
-            return self.simple_update(
+            return self.db.simple_update(
                 table="e2e_room_keys_versions",
                 keyvalues={"user_id": user_id, "version": version},
                 updatevalues=updatevalues,
@@ -420,19 +420,19 @@ class EndToEndRoomKeyStore(SQLBaseStore):
             else:
                 this_version = version
 
-            self.simple_delete_txn(
+            self.db.simple_delete_txn(
                 txn,
                 table="e2e_room_keys",
                 keyvalues={"user_id": user_id, "version": this_version},
             )
 
-            return self.simple_update_one_txn(
+            return self.db.simple_update_one_txn(
                 txn,
                 table="e2e_room_keys_versions",
                 keyvalues={"user_id": user_id, "version": this_version},
                 updatevalues={"deleted": 1},
             )
 
-        return self.runInteraction(
+        return self.db.runInteraction(
             "delete_e2e_room_keys_version", _delete_e2e_room_keys_version_txn
         )
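
The _txn-suffixed helpers used above (simple_delete_txn, simple_update_one_txn) take an already-open transaction, so several of them compose atomically inside one runInteraction; the plain variants each open a transaction of their own. A sketch of the distinction, with hypothetical table names:

    from synapse.storage._base import SQLBaseStore


    class ExampleRoomKeyStore(SQLBaseStore):
        def retire_backup(self, owner):
            # Both statements share one transaction: if flagging the version
            # row fails, the delete of the data rows is rolled back with it.
            def retire_txn(txn):
                self.db.simple_delete_txn(
                    txn, table="example_backup_rows", keyvalues={"owner": owner}
                )
                return self.db.simple_update_one_txn(
                    txn,
                    table="example_backup_versions",
                    keyvalues={"owner": owner},
                    updatevalues={"deleted": 1},
                )

            return self.db.runInteraction("retire_backup", retire_txn)
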
diff --git a/synapse/storage/data_stores/main/end_to_end_keys.py b/synapse/storage/data_stores/main/end_to_end_keys.py
index 08bcdc4725..38cd0ca9b8 100644
--- a/synapse/storage/data_stores/main/end_to_end_keys.py
+++ b/synapse/storage/data_stores/main/end_to_end_keys.py
@@ -48,7 +48,7 @@ class EndToEndKeyWorkerStore(SQLBaseStore):
         if not query_list:
             return {}
 
-        results = yield self.runInteraction(
+        results = yield self.db.runInteraction(
             "get_e2e_device_keys",
             self._get_e2e_device_keys_txn,
             query_list,
@@ -125,7 +125,7 @@ class EndToEndKeyWorkerStore(SQLBaseStore):
         )
 
         txn.execute(sql, query_params)
-        rows = self.cursor_to_dict(txn)
+        rows = self.db.cursor_to_dict(txn)
 
         result = {}
         for row in rows:
@@ -143,7 +143,7 @@ class EndToEndKeyWorkerStore(SQLBaseStore):
         )
 
         txn.execute(signature_sql, signature_query_params)
-        rows = self.cursor_to_dict(txn)
+        rows = self.db.cursor_to_dict(txn)
 
         # add each cross-signing signature to the correct device in the result dict.
         for row in rows:
@@ -186,7 +186,7 @@ class EndToEndKeyWorkerStore(SQLBaseStore):
             key_id) to json string for key
         """
 
-        rows = yield self.simple_select_many_batch(
+        rows = yield self.db.simple_select_many_batch(
             table="e2e_one_time_keys_json",
             column="key_id",
             iterable=key_ids,
@@ -219,7 +219,7 @@ class EndToEndKeyWorkerStore(SQLBaseStore):
             # a unique constraint. If there is a race of two calls to
             # `add_e2e_one_time_keys` then they'll conflict and we will only
             # insert one set.
-            self.simple_insert_many_txn(
+            self.db.simple_insert_many_txn(
                 txn,
                 table="e2e_one_time_keys_json",
                 values=[
@@ -238,7 +238,7 @@ class EndToEndKeyWorkerStore(SQLBaseStore):
                 txn, self.count_e2e_one_time_keys, (user_id, device_id)
             )
 
-        yield self.runInteraction(
+        yield self.db.runInteraction(
             "add_e2e_one_time_keys_insert", _add_e2e_one_time_keys
         )
 
@@ -261,7 +261,9 @@ class EndToEndKeyWorkerStore(SQLBaseStore):
                 result[algorithm] = key_count
             return result
 
-        return self.runInteraction("count_e2e_one_time_keys", _count_e2e_one_time_keys)
+        return self.db.runInteraction(
+            "count_e2e_one_time_keys", _count_e2e_one_time_keys
+        )
 
     def _get_e2e_cross_signing_key_txn(self, txn, user_id, key_type, from_user_id=None):
         """Returns a user's cross-signing key.
@@ -322,7 +324,7 @@ class EndToEndKeyWorkerStore(SQLBaseStore):
         Returns:
             dict of the key data or None if not found
         """
-        return self.runInteraction(
+        return self.db.runInteraction(
             "get_e2e_cross_signing_key",
             self._get_e2e_cross_signing_key_txn,
             user_id,
@@ -350,7 +352,7 @@ class EndToEndKeyWorkerStore(SQLBaseStore):
             WHERE ? < stream_id AND stream_id <= ?
             GROUP BY user_id
         """
-        return self.execute(
+        return self.db.execute(
             "get_all_user_signature_changes_for_remotes", None, sql, from_key, to_key
         )
 
@@ -367,7 +369,7 @@ class EndToEndKeyStore(EndToEndKeyWorkerStore, SQLBaseStore):
             set_tag("time_now", time_now)
             set_tag("device_keys", device_keys)
 
-            old_key_json = self.simple_select_one_onecol_txn(
+            old_key_json = self.db.simple_select_one_onecol_txn(
                 txn,
                 table="e2e_device_keys_json",
                 keyvalues={"user_id": user_id, "device_id": device_id},
@@ -383,7 +385,7 @@ class EndToEndKeyStore(EndToEndKeyWorkerStore, SQLBaseStore):
                 log_kv({"Message": "Device key already stored."})
                 return False
 
-            self.simple_upsert_txn(
+            self.db.simple_upsert_txn(
                 txn,
                 table="e2e_device_keys_json",
                 keyvalues={"user_id": user_id, "device_id": device_id},
@@ -392,7 +394,7 @@ class EndToEndKeyStore(EndToEndKeyWorkerStore, SQLBaseStore):
             log_kv({"message": "Device keys stored."})
             return True
 
-        return self.runInteraction("set_e2e_device_keys", _set_e2e_device_keys_txn)
+        return self.db.runInteraction("set_e2e_device_keys", _set_e2e_device_keys_txn)
 
     def claim_e2e_one_time_keys(self, query_list):
         """Take a list of one time keys out of the database"""
@@ -431,7 +433,9 @@ class EndToEndKeyStore(EndToEndKeyWorkerStore, SQLBaseStore):
                 )
             return result
 
-        return self.runInteraction("claim_e2e_one_time_keys", _claim_e2e_one_time_keys)
+        return self.db.runInteraction(
+            "claim_e2e_one_time_keys", _claim_e2e_one_time_keys
+        )
 
     def delete_e2e_keys_by_device(self, user_id, device_id):
         def delete_e2e_keys_by_device_txn(txn):
@@ -442,12 +446,12 @@ class EndToEndKeyStore(EndToEndKeyWorkerStore, SQLBaseStore):
                     "user_id": user_id,
                 }
             )
-            self.simple_delete_txn(
+            self.db.simple_delete_txn(
                 txn,
                 table="e2e_device_keys_json",
                 keyvalues={"user_id": user_id, "device_id": device_id},
             )
-            self.simple_delete_txn(
+            self.db.simple_delete_txn(
                 txn,
                 table="e2e_one_time_keys_json",
                 keyvalues={"user_id": user_id, "device_id": device_id},
@@ -456,7 +460,7 @@ class EndToEndKeyStore(EndToEndKeyWorkerStore, SQLBaseStore):
                 txn, self.count_e2e_one_time_keys, (user_id, device_id)
             )
 
-        return self.runInteraction(
+        return self.db.runInteraction(
             "delete_e2e_keys_by_device", delete_e2e_keys_by_device_txn
         )
 
@@ -492,7 +496,7 @@ class EndToEndKeyStore(EndToEndKeyWorkerStore, SQLBaseStore):
         # The "keys" property must only have one entry, which will be the public
         # key, so we just grab the first value in there
         pubkey = next(iter(key["keys"].values()))
-        self.simple_insert_txn(
+        self.db.simple_insert_txn(
             txn,
             "devices",
             values={
@@ -505,7 +509,7 @@ class EndToEndKeyStore(EndToEndKeyWorkerStore, SQLBaseStore):
 
         # and finally, store the key itself
         with self._cross_signing_id_gen.get_next() as stream_id:
-            self.simple_insert_txn(
+            self.db.simple_insert_txn(
                 txn,
                 "e2e_cross_signing_keys",
                 values={
@@ -524,7 +528,7 @@ class EndToEndKeyStore(EndToEndKeyWorkerStore, SQLBaseStore):
             key_type (str): the type of cross-signing key to set
             key (dict): the key data
         """
-        return self.runInteraction(
+        return self.db.runInteraction(
             "add_e2e_cross_signing_key",
             self._set_e2e_cross_signing_key_txn,
             user_id,
@@ -539,7 +543,7 @@ class EndToEndKeyStore(EndToEndKeyWorkerStore, SQLBaseStore):
             user_id (str): the user who made the signatures
             signatures (iterable[SignatureListItem]): signatures to add
         """
-        return self.simple_insert_many(
+        return self.db.simple_insert_many(
             "e2e_cross_signing_signatures",
             [
                 {
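
cursor_to_dict has moved onto the Database object as well, so transaction functions that run raw SQL now convert rows via self.db. A sketch under the same convention (the query is illustrative; the table and columns appear in the hunks above):

    from synapse.storage._base import SQLBaseStore


    class ExampleKeyCountStore(SQLBaseStore):
        def count_one_time_keys(self, user_id):
            def count_txn(txn):
                # cursor_to_dict turns the cursor rows into dicts keyed by
                # column name, here {"algorithm": ..., "count": ...}.
                sql = (
                    "SELECT algorithm, COUNT(key_id) AS count"
                    " FROM e2e_one_time_keys_json"
                    " WHERE user_id = ?"
                    " GROUP BY algorithm"
                )
                txn.execute(sql, (user_id,))
                return self.db.cursor_to_dict(txn)

            return self.db.runInteraction("count_one_time_keys", count_txn)
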
diff --git a/synapse/storage/data_stores/main/event_federation.py b/synapse/storage/data_stores/main/event_federation.py
index 051ac7a8cb..1f517e8fad 100644
--- a/synapse/storage/data_stores/main/event_federation.py
+++ b/synapse/storage/data_stores/main/event_federation.py
@@ -28,6 +28,7 @@ from synapse.metrics.background_process_metrics import run_as_background_process
 from synapse.storage._base import SQLBaseStore, make_in_list_sql_clause
 from synapse.storage.data_stores.main.events_worker import EventsWorkerStore
 from synapse.storage.data_stores.main.signatures import SignatureWorkerStore
+from synapse.storage.database import Database
 from synapse.util.caches.descriptors import cached
 
 logger = logging.getLogger(__name__)
@@ -58,7 +59,7 @@ class EventFederationWorkerStore(EventsWorkerStore, SignatureWorkerStore, SQLBas
         Returns:
             list of event_ids
         """
-        return self.runInteraction(
+        return self.db.runInteraction(
             "get_auth_chain_ids", self._get_auth_chain_ids_txn, event_ids, include_given
         )
 
@@ -90,12 +91,12 @@ class EventFederationWorkerStore(EventsWorkerStore, SignatureWorkerStore, SQLBas
         return list(results)
 
     def get_oldest_events_in_room(self, room_id):
-        return self.runInteraction(
+        return self.db.runInteraction(
             "get_oldest_events_in_room", self._get_oldest_events_in_room_txn, room_id
         )
 
     def get_oldest_events_with_depth_in_room(self, room_id):
-        return self.runInteraction(
+        return self.db.runInteraction(
             "get_oldest_events_with_depth_in_room",
             self.get_oldest_events_with_depth_in_room_txn,
             room_id,
@@ -126,7 +127,7 @@ class EventFederationWorkerStore(EventsWorkerStore, SignatureWorkerStore, SQLBas
         Returns:
             Deferred[int]
         """
-        rows = yield self.simple_select_many_batch(
+        rows = yield self.db.simple_select_many_batch(
             table="events",
             column="event_id",
             iterable=event_ids,
@@ -140,7 +141,7 @@ class EventFederationWorkerStore(EventsWorkerStore, SignatureWorkerStore, SQLBas
             return max(row["depth"] for row in rows)
 
     def _get_oldest_events_in_room_txn(self, txn, room_id):
-        return self.simple_select_onecol_txn(
+        return self.db.simple_select_onecol_txn(
             txn,
             table="event_backward_extremities",
             keyvalues={"room_id": room_id},
@@ -188,7 +189,7 @@ class EventFederationWorkerStore(EventsWorkerStore, SignatureWorkerStore, SQLBas
                 where *hashes* is a map from algorithm to hash.
         """
 
-        return self.runInteraction(
+        return self.db.runInteraction(
             "get_latest_event_ids_and_hashes_in_room",
             self._get_latest_event_ids_and_hashes_in_room,
             room_id,
@@ -229,13 +230,13 @@ class EventFederationWorkerStore(EventsWorkerStore, SignatureWorkerStore, SQLBas
             txn.execute(sql, query_args)
             return [room_id for room_id, in txn]
 
-        return self.runInteraction(
+        return self.db.runInteraction(
             "get_rooms_with_many_extremities", _get_rooms_with_many_extremities_txn
         )
 
     @cached(max_entries=5000, iterable=True)
     def get_latest_event_ids_in_room(self, room_id):
-        return self.simple_select_onecol(
+        return self.db.simple_select_onecol(
             table="event_forward_extremities",
             keyvalues={"room_id": room_id},
             retcol="event_id",
@@ -266,12 +267,12 @@ class EventFederationWorkerStore(EventsWorkerStore, SignatureWorkerStore, SQLBas
     def get_min_depth(self, room_id):
         """ For hte given room, get the minimum depth we have seen for it.
         """
-        return self.runInteraction(
+        return self.db.runInteraction(
             "get_min_depth", self._get_min_depth_interaction, room_id
         )
 
     def _get_min_depth_interaction(self, txn, room_id):
-        min_depth = self.simple_select_one_onecol_txn(
+        min_depth = self.db.simple_select_one_onecol_txn(
             txn,
             table="room_depth",
             keyvalues={"room_id": room_id},
@@ -337,7 +338,7 @@ class EventFederationWorkerStore(EventsWorkerStore, SignatureWorkerStore, SQLBas
             txn.execute(sql, (stream_ordering, room_id))
             return [event_id for event_id, in txn]
 
-        return self.runInteraction(
+        return self.db.runInteraction(
             "get_forward_extremeties_for_room", get_forward_extremeties_for_room_txn
         )
 
@@ -352,7 +353,7 @@ class EventFederationWorkerStore(EventsWorkerStore, SignatureWorkerStore, SQLBas
             limit (int)
         """
         return (
-            self.runInteraction(
+            self.db.runInteraction(
                 "get_backfill_events",
                 self._get_backfill_events,
                 room_id,
@@ -383,7 +384,7 @@ class EventFederationWorkerStore(EventsWorkerStore, SignatureWorkerStore, SQLBas
         queue = PriorityQueue()
 
         for event_id in event_list:
-            depth = self.simple_select_one_onecol_txn(
+            depth = self.db.simple_select_one_onecol_txn(
                 txn,
                 table="events",
                 keyvalues={"event_id": event_id, "room_id": room_id},
@@ -415,7 +416,7 @@ class EventFederationWorkerStore(EventsWorkerStore, SignatureWorkerStore, SQLBas
 
     @defer.inlineCallbacks
     def get_missing_events(self, room_id, earliest_events, latest_events, limit):
-        ids = yield self.runInteraction(
+        ids = yield self.db.runInteraction(
             "get_missing_events",
             self._get_missing_events,
             room_id,
@@ -468,7 +469,7 @@ class EventFederationWorkerStore(EventsWorkerStore, SignatureWorkerStore, SQLBas
         Returns:
             Deferred[list[str]]
         """
-        rows = yield self.simple_select_many_batch(
+        rows = yield self.db.simple_select_many_batch(
             table="event_edges",
             column="prev_event_id",
             iterable=event_ids,
@@ -491,10 +492,10 @@ class EventFederationStore(EventFederationWorkerStore):
 
     EVENT_AUTH_STATE_ONLY = "event_auth_state_only"
 
-    def __init__(self, db_conn, hs):
-        super(EventFederationStore, self).__init__(db_conn, hs)
+    def __init__(self, database: Database, db_conn, hs):
+        super(EventFederationStore, self).__init__(database, db_conn, hs)
 
-        self.register_background_update_handler(
+        self.db.updates.register_background_update_handler(
             self.EVENT_AUTH_STATE_ONLY, self._background_delete_non_state_event_auth
         )
 
@@ -508,7 +509,7 @@ class EventFederationStore(EventFederationWorkerStore):
         if min_depth and depth >= min_depth:
             return
 
-        self.simple_upsert_txn(
+        self.db.simple_upsert_txn(
             txn,
             table="room_depth",
             keyvalues={"room_id": room_id},
@@ -520,7 +521,7 @@ class EventFederationStore(EventFederationWorkerStore):
         For the given event, update the event edges table and forward and
         backward extremities tables.
         """
-        self.simple_insert_many_txn(
+        self.db.simple_insert_many_txn(
             txn,
             table="event_edges",
             values=[
@@ -604,13 +605,13 @@ class EventFederationStore(EventFederationWorkerStore):
 
         return run_as_background_process(
             "delete_old_forward_extrem_cache",
-            self.runInteraction,
+            self.db.runInteraction,
             "_delete_old_forward_extrem_cache",
             _delete_old_forward_extrem_cache_txn,
         )
 
     def clean_room_for_join(self, room_id):
-        return self.runInteraction(
+        return self.db.runInteraction(
             "clean_room_for_join", self._clean_room_for_join_txn, room_id
         )
 
@@ -654,17 +655,17 @@ class EventFederationStore(EventFederationWorkerStore):
                 "max_stream_id_exclusive": min_stream_id,
             }
 
-            self._background_update_progress_txn(
+            self.db.updates._background_update_progress_txn(
                 txn, self.EVENT_AUTH_STATE_ONLY, new_progress
             )
 
             return min_stream_id >= target_min_stream_id
 
-        result = yield self.runInteraction(
+        result = yield self.db.runInteraction(
             self.EVENT_AUTH_STATE_ONLY, delete_event_auth
         )
 
         if not result:
-            yield self._end_background_update(self.EVENT_AUTH_STATE_ONLY)
+            yield self.db.updates._end_background_update(self.EVENT_AUTH_STATE_ONLY)
 
         return batch_size
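
The constructor hunks above show the other half of the refactor: stores now receive the Database explicitly, and background-update plumbing hangs off self.db.updates rather than a BackgroundUpdateStore mixin. A sketch of a store under the new signature (the update name and handler are hypothetical):

    from twisted.internet import defer

    from synapse.storage._base import SQLBaseStore
    from synapse.storage.database import Database


    class ExampleFederationStore(SQLBaseStore):
        EXAMPLE_UPDATE = "example_background_update"  # hypothetical job name

        def __init__(self, database: Database, db_conn, hs):
            super(ExampleFederationStore, self).__init__(database, db_conn, hs)

            self.db.updates.register_background_update_handler(
                self.EXAMPLE_UPDATE, self._run_example_update
            )

        @defer.inlineCallbacks
        def _run_example_update(self, progress, batch_size):
            # A real handler would process up to batch_size rows and record
            # progress; here we simply declare the update finished.
            yield self.db.updates._end_background_update(self.EXAMPLE_UPDATE)
            return 1
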
diff --git a/synapse/storage/data_stores/main/event_push_actions.py b/synapse/storage/data_stores/main/event_push_actions.py
index 0a37847cfd..9988a6d3fc 100644
--- a/synapse/storage/data_stores/main/event_push_actions.py
+++ b/synapse/storage/data_stores/main/event_push_actions.py
@@ -24,6 +24,7 @@ from twisted.internet import defer
 
 from synapse.metrics.background_process_metrics import run_as_background_process
 from synapse.storage._base import LoggingTransaction, SQLBaseStore
+from synapse.storage.database import Database
 from synapse.util.caches.descriptors import cachedInlineCallbacks
 
 logger = logging.getLogger(__name__)
@@ -68,8 +69,8 @@ def _deserialize_action(actions, is_highlight):
 
 
 class EventPushActionsWorkerStore(SQLBaseStore):
-    def __init__(self, db_conn, hs):
-        super(EventPushActionsWorkerStore, self).__init__(db_conn, hs)
+    def __init__(self, database: Database, db_conn, hs):
+        super(EventPushActionsWorkerStore, self).__init__(database, db_conn, hs)
 
         # These get correctly set by _find_stream_orderings_for_times_txn
         self.stream_ordering_month_ago = None
@@ -93,7 +94,7 @@ class EventPushActionsWorkerStore(SQLBaseStore):
     def get_unread_event_push_actions_by_room_for_user(
         self, room_id, user_id, last_read_event_id
     ):
-        ret = yield self.runInteraction(
+        ret = yield self.db.runInteraction(
             "get_unread_event_push_actions_by_room",
             self._get_unread_counts_by_receipt_txn,
             room_id,
@@ -177,7 +178,7 @@ class EventPushActionsWorkerStore(SQLBaseStore):
             txn.execute(sql, (min_stream_ordering, max_stream_ordering))
             return [r[0] for r in txn]
 
-        ret = yield self.runInteraction("get_push_action_users_in_range", f)
+        ret = yield self.db.runInteraction("get_push_action_users_in_range", f)
         return ret
 
     @defer.inlineCallbacks
@@ -229,7 +230,7 @@ class EventPushActionsWorkerStore(SQLBaseStore):
             txn.execute(sql, args)
             return txn.fetchall()
 
-        after_read_receipt = yield self.runInteraction(
+        after_read_receipt = yield self.db.runInteraction(
             "get_unread_push_actions_for_user_in_range_http_arr", get_after_receipt
         )
 
@@ -257,7 +258,7 @@ class EventPushActionsWorkerStore(SQLBaseStore):
             txn.execute(sql, args)
             return txn.fetchall()
 
-        no_read_receipt = yield self.runInteraction(
+        no_read_receipt = yield self.db.runInteraction(
             "get_unread_push_actions_for_user_in_range_http_nrr", get_no_receipt
         )
 
@@ -329,7 +330,7 @@ class EventPushActionsWorkerStore(SQLBaseStore):
             txn.execute(sql, args)
             return txn.fetchall()
 
-        after_read_receipt = yield self.runInteraction(
+        after_read_receipt = yield self.db.runInteraction(
             "get_unread_push_actions_for_user_in_range_email_arr", get_after_receipt
         )
 
@@ -357,7 +358,7 @@ class EventPushActionsWorkerStore(SQLBaseStore):
             txn.execute(sql, args)
             return txn.fetchall()
 
-        no_read_receipt = yield self.runInteraction(
+        no_read_receipt = yield self.db.runInteraction(
             "get_unread_push_actions_for_user_in_range_email_nrr", get_no_receipt
         )
 
@@ -407,7 +408,7 @@ class EventPushActionsWorkerStore(SQLBaseStore):
             txn.execute(sql, (user_id, min_stream_ordering))
             return bool(txn.fetchone())
 
-        return self.runInteraction(
+        return self.db.runInteraction(
             "get_if_maybe_push_in_range_for_user",
             _get_if_maybe_push_in_range_for_user_txn,
         )
@@ -458,7 +459,7 @@ class EventPushActionsWorkerStore(SQLBaseStore):
                 ),
             )
 
-        return self.runInteraction(
+        return self.db.runInteraction(
             "add_push_actions_to_staging", _add_push_actions_to_staging_txn
         )
 
@@ -472,7 +473,7 @@ class EventPushActionsWorkerStore(SQLBaseStore):
         """
 
         try:
-            res = yield self.simple_delete(
+            res = yield self.db.simple_delete(
                 table="event_push_actions_staging",
                 keyvalues={"event_id": event_id},
                 desc="remove_push_actions_from_staging",
@@ -489,7 +490,7 @@ class EventPushActionsWorkerStore(SQLBaseStore):
     def _find_stream_orderings_for_times(self):
         return run_as_background_process(
             "event_push_action_stream_orderings",
-            self.runInteraction,
+            self.db.runInteraction,
             "_find_stream_orderings_for_times",
             self._find_stream_orderings_for_times_txn,
         )
@@ -525,7 +526,7 @@ class EventPushActionsWorkerStore(SQLBaseStore):
             Deferred[int]: stream ordering of the first event received on/after
                 the timestamp
         """
-        return self.runInteraction(
+        return self.db.runInteraction(
             "_find_first_stream_ordering_after_ts_txn",
             self._find_first_stream_ordering_after_ts_txn,
             ts,
@@ -611,17 +612,17 @@ class EventPushActionsWorkerStore(SQLBaseStore):
 class EventPushActionsStore(EventPushActionsWorkerStore):
     EPA_HIGHLIGHT_INDEX = "epa_highlight_index"
 
-    def __init__(self, db_conn, hs):
-        super(EventPushActionsStore, self).__init__(db_conn, hs)
+    def __init__(self, database: Database, db_conn, hs):
+        super(EventPushActionsStore, self).__init__(database, db_conn, hs)
 
-        self.register_background_index_update(
+        self.db.updates.register_background_index_update(
             self.EPA_HIGHLIGHT_INDEX,
             index_name="event_push_actions_u_highlight",
             table="event_push_actions",
             columns=["user_id", "stream_ordering"],
         )
 
-        self.register_background_index_update(
+        self.db.updates.register_background_index_update(
             "event_push_actions_highlights_index",
             index_name="event_push_actions_highlights_index",
             table="event_push_actions",
@@ -677,7 +678,7 @@ class EventPushActionsStore(EventPushActionsWorkerStore):
             )
 
         for event, _ in events_and_contexts:
-            user_ids = self.simple_select_onecol_txn(
+            user_ids = self.db.simple_select_onecol_txn(
                 txn,
                 table="event_push_actions_staging",
                 keyvalues={"event_id": event.event_id},
@@ -727,9 +728,9 @@ class EventPushActionsStore(EventPushActionsWorkerStore):
                 " LIMIT ?" % (before_clause,)
             )
             txn.execute(sql, args)
-            return self.cursor_to_dict(txn)
+            return self.db.cursor_to_dict(txn)
 
-        push_actions = yield self.runInteraction("get_push_actions_for_user", f)
+        push_actions = yield self.db.runInteraction("get_push_actions_for_user", f)
         for pa in push_actions:
             pa["actions"] = _deserialize_action(pa["actions"], pa["highlight"])
         return push_actions
@@ -748,7 +749,7 @@ class EventPushActionsStore(EventPushActionsWorkerStore):
             txn.execute(sql, (stream_ordering,))
             return txn.fetchone()
 
-        result = yield self.runInteraction("get_time_of_last_push_action_before", f)
+        result = yield self.db.runInteraction("get_time_of_last_push_action_before", f)
         return result[0] if result else None
 
     @defer.inlineCallbacks
@@ -757,7 +758,9 @@ class EventPushActionsStore(EventPushActionsWorkerStore):
             txn.execute("SELECT MAX(stream_ordering) FROM event_push_actions")
             return txn.fetchone()
 
-        result = yield self.runInteraction("get_latest_push_action_stream_ordering", f)
+        result = yield self.db.runInteraction(
+            "get_latest_push_action_stream_ordering", f
+        )
         return result[0] or 0
 
     def _remove_push_actions_for_event_id_txn(self, txn, room_id, event_id):
@@ -830,7 +833,7 @@ class EventPushActionsStore(EventPushActionsWorkerStore):
             while True:
                 logger.info("Rotating notifications")
 
-                caught_up = yield self.runInteraction(
+                caught_up = yield self.db.runInteraction(
                     "_rotate_notifs", self._rotate_notifs_txn
                 )
                 if caught_up:
@@ -844,7 +847,7 @@ class EventPushActionsStore(EventPushActionsWorkerStore):
         the archiving process has caught up or not.
         """
 
-        old_rotate_stream_ordering = self.simple_select_one_onecol_txn(
+        old_rotate_stream_ordering = self.db.simple_select_one_onecol_txn(
             txn,
             table="event_push_summary_stream_ordering",
             keyvalues={},
@@ -880,7 +883,7 @@ class EventPushActionsStore(EventPushActionsWorkerStore):
         return caught_up
 
     def _rotate_notifs_before_txn(self, txn, rotate_to_stream_ordering):
-        old_rotate_stream_ordering = self.simple_select_one_onecol_txn(
+        old_rotate_stream_ordering = self.db.simple_select_one_onecol_txn(
             txn,
             table="event_push_summary_stream_ordering",
             keyvalues={},
@@ -912,7 +915,7 @@ class EventPushActionsStore(EventPushActionsWorkerStore):
         # If the `old.user_id` above is NULL then we know there isn't already an
         # entry in the table, so we simply insert it. Otherwise we update the
         # existing table.
-        self.simple_insert_many_txn(
+        self.db.simple_insert_many_txn(
             txn,
             table="event_push_summary",
             values=[
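
Also visible in this file: fire-and-forget maintenance work is wrapped in run_as_background_process, which now receives self.db.runInteraction as the callable, followed by the transaction description and the txn function. A sketch with a hypothetical pruning job:

    from synapse.metrics.background_process_metrics import run_as_background_process
    from synapse.storage._base import SQLBaseStore


    class ExamplePushStore(SQLBaseStore):
        def prune_old_staging_rows(self, before_ts):
            # Runs off the request path, with its own log context and metrics.
            def prune_txn(txn):
                txn.execute(
                    "DELETE FROM example_staging WHERE created_ts < ?",
                    (before_ts,),
                )

            return run_as_background_process(
                "prune_old_staging_rows",  # name reported in process metrics
                self.db.runInteraction,
                "prune_old_staging_rows_txn",
                prune_txn,
            )
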
diff --git a/synapse/storage/data_stores/main/events.py b/synapse/storage/data_stores/main/events.py
index 98ae69e996..998bba1aad 100644
--- a/synapse/storage/data_stores/main/events.py
+++ b/synapse/storage/data_stores/main/events.py
@@ -38,10 +38,10 @@ from synapse.logging.utils import log_function
 from synapse.metrics import BucketCollector
 from synapse.metrics.background_process_metrics import run_as_background_process
 from synapse.storage._base import make_in_list_sql_clause
-from synapse.storage.background_updates import BackgroundUpdateStore
 from synapse.storage.data_stores.main.event_federation import EventFederationStore
 from synapse.storage.data_stores.main.events_worker import EventsWorkerStore
 from synapse.storage.data_stores.main.state import StateGroupWorkerStore
+from synapse.storage.database import Database
 from synapse.types import RoomStreamToken, get_domain_from_id
 from synapse.util import batch_iter
 from synapse.util.caches.descriptors import cached, cachedInlineCallbacks
@@ -94,13 +94,10 @@ def _retry_on_integrity_error(func):
 # inherits from EventFederationStore so that we can call _update_backward_extremities
 # and _handle_mult_prev_events (though arguably those could both be moved in here)
 class EventsStore(
-    StateGroupWorkerStore,
-    EventFederationStore,
-    EventsWorkerStore,
-    BackgroundUpdateStore,
+    StateGroupWorkerStore, EventFederationStore, EventsWorkerStore,
 ):
-    def __init__(self, db_conn, hs):
-        super(EventsStore, self).__init__(db_conn, hs)
+    def __init__(self, database: Database, db_conn, hs):
+        super(EventsStore, self).__init__(database, db_conn, hs)
 
         # Collect metrics on the number of forward extremities that exist.
         # Counter of number of extremities to count
@@ -143,7 +140,7 @@ class EventsStore(
             )
             return txn.fetchall()
 
-        res = yield self.runInteraction("read_forward_extremities", fetch)
+        res = yield self.db.runInteraction("read_forward_extremities", fetch)
         self._current_forward_extremities_amount = c_counter(list(x[0] for x in res))
 
     @_retry_on_integrity_error
@@ -208,7 +205,7 @@ class EventsStore(
             for (event, context), stream in zip(events_and_contexts, stream_orderings):
                 event.internal_metadata.stream_ordering = stream
 
-            yield self.runInteraction(
+            yield self.db.runInteraction(
                 "persist_events",
                 self._persist_events_txn,
                 events_and_contexts=events_and_contexts,
@@ -281,7 +278,7 @@ class EventsStore(
             results.extend(r[0] for r in txn if not json.loads(r[1]).get("soft_failed"))
 
         for chunk in batch_iter(event_ids, 100):
-            yield self.runInteraction(
+            yield self.db.runInteraction(
                 "_get_events_which_are_prevs", _get_events_which_are_prevs_txn, chunk
             )
 
@@ -345,7 +342,7 @@ class EventsStore(
                         existing_prevs.add(prev_event_id)
 
         for chunk in batch_iter(event_ids, 100):
-            yield self.runInteraction(
+            yield self.db.runInteraction(
                 "_get_prevs_before_rejected", _get_prevs_before_rejected_txn, chunk
             )
 
@@ -432,7 +429,7 @@ class EventsStore(
         # event's auth chain, but it's easier for now just to store them (and
         # it doesn't take much storage compared to storing the entire event
         # anyway).
-        self.simple_insert_many_txn(
+        self.db.simple_insert_many_txn(
             txn,
             table="event_auth",
             values=[
@@ -580,12 +577,12 @@ class EventsStore(
         self, txn, new_forward_extremities, max_stream_order
     ):
         for room_id, new_extrem in iteritems(new_forward_extremities):
-            self.simple_delete_txn(
+            self.db.simple_delete_txn(
                 txn, table="event_forward_extremities", keyvalues={"room_id": room_id}
             )
             txn.call_after(self.get_latest_event_ids_in_room.invalidate, (room_id,))
 
-        self.simple_insert_many_txn(
+        self.db.simple_insert_many_txn(
             txn,
             table="event_forward_extremities",
             values=[
@@ -598,7 +595,7 @@ class EventsStore(
         # new stream_ordering to new forward extremities in the room.
         # This allows us to later efficiently look up the forward extremities
         # for a room before a given stream_ordering
-        self.simple_insert_many_txn(
+        self.db.simple_insert_many_txn(
             txn,
             table="stream_ordering_to_exterm",
             values=[
@@ -722,7 +719,7 @@ class EventsStore(
                 # change in outlier status to our workers.
                 stream_order = event.internal_metadata.stream_ordering
                 state_group_id = context.state_group
-                self.simple_insert_txn(
+                self.db.simple_insert_txn(
                     txn,
                     table="ex_outlier_stream",
                     values={
@@ -794,7 +791,7 @@ class EventsStore(
             d.pop("redacted_because", None)
             return d
 
-        self.simple_insert_many_txn(
+        self.db.simple_insert_many_txn(
             txn,
             table="event_json",
             values=[
@@ -811,7 +808,7 @@ class EventsStore(
             ],
         )
 
-        self.simple_insert_many_txn(
+        self.db.simple_insert_many_txn(
             txn,
             table="events",
             values=[
@@ -841,7 +838,7 @@ class EventsStore(
                 # If we're persisting an unredacted event we go and ensure
                 # that we mark any redactions that reference this event as
                 # requiring censoring.
-                self.simple_update_txn(
+                self.db.simple_update_txn(
                     txn,
                     table="redactions",
                     keyvalues={"redacts": event.event_id},
@@ -983,7 +980,7 @@ class EventsStore(
 
             state_values.append(vals)
 
-        self.simple_insert_many_txn(txn, table="state_events", values=state_values)
+        self.db.simple_insert_many_txn(txn, table="state_events", values=state_values)
 
         # Prefill the event cache
         self._add_to_cache(txn, events_and_contexts)
@@ -1014,7 +1011,7 @@ class EventsStore(
             )
 
             txn.execute(sql + clause, args)
-            rows = self.cursor_to_dict(txn)
+            rows = self.db.cursor_to_dict(txn)
             for row in rows:
                 event = ev_map[row["event_id"]]
                 if not row["rejects"] and not row["redacts"]:
@@ -1032,7 +1029,7 @@ class EventsStore(
         # invalidate the cache for the redacted event
         txn.call_after(self._invalidate_get_event_cache, event.redacts)
 
-        self.simple_insert_txn(
+        self.db.simple_insert_txn(
             txn,
             table="redactions",
             values={
@@ -1042,20 +1039,25 @@ class EventsStore(
             },
         )
 
-    @defer.inlineCallbacks
-    def _censor_redactions(self):
+    async def _censor_redactions(self):
         """Censors all redactions older than the configured period that haven't
         been censored yet.
 
         By censor we mean update the event_json table with the redacted event.
-
-        Returns:
-            Deferred
         """
 
         if self.hs.config.redaction_retention_period is None:
             return
 
+        if not (
+            await self.db.updates.has_completed_background_update(
+                "redactions_have_censored_ts_idx"
+            )
+        ):
+            # We don't want to run this until the appropriate index has been
+            # created.
+            return
+
         before_ts = self._clock.time_msec() - self.hs.config.redaction_retention_period
 
         # We fetch all redactions that:
@@ -1077,13 +1079,15 @@ class EventsStore(
             LIMIT ?
         """
 
-        rows = yield self.execute("_censor_redactions_fetch", None, sql, before_ts, 100)
+        rows = await self.db.execute(
+            "_censor_redactions_fetch", None, sql, before_ts, 100
+        )
 
         updates = []
 
         for redaction_id, event_id in rows:
-            redaction_event = yield self.get_event(redaction_id, allow_none=True)
-            original_event = yield self.get_event(
+            redaction_event = await self.get_event(redaction_id, allow_none=True)
+            original_event = await self.get_event(
                 event_id, allow_rejected=True, allow_none=True
             )
 
@@ -1109,14 +1113,14 @@ class EventsStore(
                 if pruned_json:
                     self._censor_event_txn(txn, event_id, pruned_json)
 
-                self.simple_update_one_txn(
+                self.db.simple_update_one_txn(
                     txn,
                     table="redactions",
                     keyvalues={"event_id": redaction_id},
                     updatevalues={"have_censored": True},
                 )
 
-        yield self.runInteraction("_update_censor_txn", _update_censor_txn)
+        await self.db.runInteraction("_update_censor_txn", _update_censor_txn)
 
     def _censor_event_txn(self, txn, event_id, pruned_json):
         """Censor an event by replacing its JSON in the event_json table with the
@@ -1127,7 +1131,7 @@ class EventsStore(
             event_id (str): The ID of the event to censor.
             pruned_json (str): The pruned JSON
         """
-        self.simple_update_one_txn(
+        self.db.simple_update_one_txn(
             txn,
             table="event_json",
             keyvalues={"event_id": event_id},
@@ -1153,7 +1157,7 @@ class EventsStore(
             (count,) = txn.fetchone()
             return count
 
-        ret = yield self.runInteraction("count_messages", _count_messages)
+        ret = yield self.db.runInteraction("count_messages", _count_messages)
         return ret
 
     @defer.inlineCallbacks
@@ -1174,7 +1178,7 @@ class EventsStore(
             (count,) = txn.fetchone()
             return count
 
-        ret = yield self.runInteraction("count_daily_sent_messages", _count_messages)
+        ret = yield self.db.runInteraction("count_daily_sent_messages", _count_messages)
         return ret
 
     @defer.inlineCallbacks
@@ -1189,7 +1193,7 @@ class EventsStore(
             (count,) = txn.fetchone()
             return count
 
-        ret = yield self.runInteraction("count_daily_active_rooms", _count)
+        ret = yield self.db.runInteraction("count_daily_active_rooms", _count)
         return ret
 
     def get_current_backfill_token(self):
@@ -1241,7 +1245,7 @@ class EventsStore(
 
             return new_event_updates
 
-        return self.runInteraction(
+        return self.db.runInteraction(
             "get_all_new_forward_event_rows", get_all_new_forward_event_rows
         )
 
@@ -1286,7 +1290,7 @@ class EventsStore(
 
             return new_event_updates
 
-        return self.runInteraction(
+        return self.db.runInteraction(
             "get_all_new_backfill_event_rows", get_all_new_backfill_event_rows
         )
 
@@ -1379,7 +1383,7 @@ class EventsStore(
                 backward_ex_outliers,
             )
 
-        return self.runInteraction("get_all_new_events", get_all_new_events_txn)
+        return self.db.runInteraction("get_all_new_events", get_all_new_events_txn)
 
     def purge_history(self, room_id, token, delete_local_events):
         """Deletes room history before a certain point
@@ -1399,7 +1403,7 @@ class EventsStore(
             deleted events.
         """
 
-        return self.runInteraction(
+        return self.db.runInteraction(
             "purge_history",
             self._purge_history_txn,
             room_id,
@@ -1647,7 +1651,7 @@ class EventsStore(
             Deferred[List[int]]: The list of state groups to delete.
         """
 
-        return self.runInteraction("purge_room", self._purge_room_txn, room_id)
+        return self.db.runInteraction("purge_room", self._purge_room_txn, room_id)
 
     def _purge_room_txn(self, txn, room_id):
         # First we fetch all the state groups that should be deleted, before
@@ -1766,7 +1770,7 @@ class EventsStore(
                 to delete.
         """
 
-        return self.runInteraction(
+        return self.db.runInteraction(
             "purge_unreferenced_state_groups",
             self._purge_unreferenced_state_groups,
             room_id,
@@ -1778,7 +1782,7 @@ class EventsStore(
             "[purge] found %i state groups to delete", len(state_groups_to_delete)
         )
 
-        rows = self.simple_select_many_txn(
+        rows = self.db.simple_select_many_txn(
             txn,
             table="state_group_edges",
             column="prev_state_group",
@@ -1805,15 +1809,15 @@ class EventsStore(
             curr_state = self._get_state_groups_from_groups_txn(txn, [sg])
             curr_state = curr_state[sg]
 
-            self.simple_delete_txn(
+            self.db.simple_delete_txn(
                 txn, table="state_groups_state", keyvalues={"state_group": sg}
             )
 
-            self.simple_delete_txn(
+            self.db.simple_delete_txn(
                 txn, table="state_group_edges", keyvalues={"state_group": sg}
             )
 
-            self.simple_insert_many_txn(
+            self.db.simple_insert_many_txn(
                 txn,
                 table="state_groups_state",
                 values=[
@@ -1850,7 +1854,7 @@ class EventsStore(
             state group.
         """
 
-        rows = yield self.simple_select_many_batch(
+        rows = yield self.db.simple_select_many_batch(
             table="state_group_edges",
             column="prev_state_group",
             iterable=state_groups,
@@ -1869,7 +1873,7 @@ class EventsStore(
             state_groups_to_delete (list[int]): State groups to delete
         """
 
-        return self.runInteraction(
+        return self.db.runInteraction(
             "purge_room_state",
             self._purge_room_state_txn,
             room_id,
@@ -1880,7 +1884,7 @@ class EventsStore(
         # first we have to delete the state groups states
         logger.info("[purge] removing %s from state_groups_state", room_id)
 
-        self.simple_delete_many_txn(
+        self.db.simple_delete_many_txn(
             txn,
             table="state_groups_state",
             column="state_group",
@@ -1891,7 +1895,7 @@ class EventsStore(
         # ... and the state group edges
         logger.info("[purge] removing %s from state_group_edges", room_id)
 
-        self.simple_delete_many_txn(
+        self.db.simple_delete_many_txn(
             txn,
             table="state_group_edges",
             column="state_group",
@@ -1902,7 +1906,7 @@ class EventsStore(
         # ... and the state groups
         logger.info("[purge] removing %s from state_groups", room_id)
 
-        self.simple_delete_many_txn(
+        self.db.simple_delete_many_txn(
             txn,
             table="state_groups",
             column="id",
@@ -1919,7 +1923,7 @@ class EventsStore(
 
     @cachedInlineCallbacks(max_entries=5000)
     def _get_event_ordering(self, event_id):
-        res = yield self.simple_select_one(
+        res = yield self.db.simple_select_one(
             table="events",
             retcols=["topological_ordering", "stream_ordering"],
             keyvalues={"event_id": event_id},
@@ -1942,7 +1946,7 @@ class EventsStore(
             txn.execute(sql, (from_token, to_token, limit))
             return txn.fetchall()
 
-        return self.runInteraction(
+        return self.db.runInteraction(
             "get_all_updated_current_state_deltas",
             get_all_updated_current_state_deltas_txn,
         )
@@ -1960,7 +1964,7 @@ class EventsStore(
             room_id (str): The ID of the room the event was sent to.
             topological_ordering (int): The position of the event in the room's topology.
         """
-        return self.simple_insert_many_txn(
+        return self.db.simple_insert_many_txn(
             txn=txn,
             table="event_labels",
             values=[
@@ -1982,7 +1986,7 @@ class EventsStore(
             event_id (str): The event ID the expiry timestamp is associated with.
             expiry_ts (int): The timestamp at which to expire (delete) the event.
         """
-        return self.simple_insert_txn(
+        return self.db.simple_insert_txn(
             txn=txn,
             table="event_expiry",
             values={"event_id": event_id, "expiry_ts": expiry_ts},
@@ -2031,7 +2035,7 @@ class EventsStore(
                 txn, "_get_event_cache", (event.event_id,)
             )
 
-        yield self.runInteraction("delete_expired_event", delete_expired_event_txn)
+        yield self.db.runInteraction("delete_expired_event", delete_expired_event_txn)
 
     def _delete_event_expiry_txn(self, txn, event_id):
         """Delete the expiry timestamp associated with an event ID without deleting the
@@ -2041,7 +2045,7 @@ class EventsStore(
             txn (LoggingTransaction): The transaction to use to perform the deletion.
             event_id (str): The event ID to delete the associated expiry timestamp of.
         """
-        return self.simple_delete_txn(
+        return self.db.simple_delete_txn(
             txn=txn, table="event_expiry", keyvalues={"event_id": event_id}
         )
 
@@ -2065,7 +2069,7 @@ class EventsStore(
 
             return txn.fetchone()
 
-        return self.runInteraction(
+        return self.db.runInteraction(
             desc="get_next_event_to_expire", func=get_next_event_to_expire_txn
         )
 
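The _censor_redactions hunks above also convert the method from @defer.inlineCallbacks to a native coroutine: yield becomes await, and self.db.execute / self.db.runInteraction are awaited directly (Deferreds are awaitable). The same conversion on a hypothetical method:

    from synapse.storage._base import SQLBaseStore


    class ExampleEventsStore(SQLBaseStore):
        # Previously this would have been decorated with @defer.inlineCallbacks
        # and written with `yield`; as a coroutine it awaits the Deferred that
        # self.db.execute returns.
        async def count_recent_rows(self, window_ms):
            since_ts = self._clock.time_msec() - window_ms
            sql = "SELECT COUNT(*) FROM example_rows WHERE received_ts > ?"
            rows = await self.db.execute("count_recent_rows", None, sql, since_ts)
            return rows[0][0]
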
diff --git a/synapse/storage/data_stores/main/events_bg_updates.py b/synapse/storage/data_stores/main/events_bg_updates.py
index 37dfc8c871..5177b71016 100644
--- a/synapse/storage/data_stores/main/events_bg_updates.py
+++ b/synapse/storage/data_stores/main/events_bg_updates.py
@@ -22,30 +22,30 @@ from canonicaljson import json
 from twisted.internet import defer
 
 from synapse.api.constants import EventContentFields
-from synapse.storage._base import make_in_list_sql_clause
-from synapse.storage.background_updates import BackgroundUpdateStore
+from synapse.storage._base import SQLBaseStore, make_in_list_sql_clause
+from synapse.storage.database import Database
 
 logger = logging.getLogger(__name__)
 
 
-class EventsBackgroundUpdatesStore(BackgroundUpdateStore):
+class EventsBackgroundUpdatesStore(SQLBaseStore):
 
     EVENT_ORIGIN_SERVER_TS_NAME = "event_origin_server_ts"
     EVENT_FIELDS_SENDER_URL_UPDATE_NAME = "event_fields_sender_url"
     DELETE_SOFT_FAILED_EXTREMITIES = "delete_soft_failed_extremities"
 
-    def __init__(self, db_conn, hs):
-        super(EventsBackgroundUpdatesStore, self).__init__(db_conn, hs)
+    def __init__(self, database: Database, db_conn, hs):
+        super(EventsBackgroundUpdatesStore, self).__init__(database, db_conn, hs)
 
-        self.register_background_update_handler(
+        self.db.updates.register_background_update_handler(
             self.EVENT_ORIGIN_SERVER_TS_NAME, self._background_reindex_origin_server_ts
         )
-        self.register_background_update_handler(
+        self.db.updates.register_background_update_handler(
             self.EVENT_FIELDS_SENDER_URL_UPDATE_NAME,
             self._background_reindex_fields_sender,
         )
 
-        self.register_background_index_update(
+        self.db.updates.register_background_index_update(
             "event_contains_url_index",
             index_name="event_contains_url_index",
             table="events",
@@ -56,7 +56,7 @@ class EventsBackgroundUpdatesStore(BackgroundUpdateStore):
         # an event_id index on event_search is useful for the purge_history
         # api. Plus it means we get to enforce some integrity with a UNIQUE
         # clause
-        self.register_background_index_update(
+        self.db.updates.register_background_index_update(
             "event_search_event_id_idx",
             index_name="event_search_event_id_idx",
             table="event_search",
@@ -65,16 +65,16 @@ class EventsBackgroundUpdatesStore(BackgroundUpdateStore):
             psql_only=True,
         )
 
-        self.register_background_update_handler(
+        self.db.updates.register_background_update_handler(
             self.DELETE_SOFT_FAILED_EXTREMITIES, self._cleanup_extremities_bg_update
         )
 
-        self.register_background_update_handler(
+        self.db.updates.register_background_update_handler(
             "redactions_received_ts", self._redactions_received_ts
         )
 
         # This index gets deleted in `event_fix_redactions_bytes` update
-        self.register_background_index_update(
+        self.db.updates.register_background_index_update(
             "event_fix_redactions_bytes_create_index",
             index_name="redactions_censored_redacts",
             table="redactions",
@@ -82,14 +82,22 @@ class EventsBackgroundUpdatesStore(BackgroundUpdateStore):
             where_clause="have_censored",
         )
 
-        self.register_background_update_handler(
+        self.db.updates.register_background_update_handler(
             "event_fix_redactions_bytes", self._event_fix_redactions_bytes
         )
 
-        self.register_background_update_handler(
+        self.db.updates.register_background_update_handler(
             "event_store_labels", self._event_store_labels
         )
 
+        self.db.updates.register_background_index_update(
+            "redactions_have_censored_ts_idx",
+            index_name="redactions_have_censored_ts",
+            table="redactions",
+            columns=["received_ts"],
+            where_clause="NOT have_censored",
+        )
+
     @defer.inlineCallbacks
     def _background_reindex_fields_sender(self, progress, batch_size):
         target_min_stream_id = progress["target_min_stream_id_inclusive"]
@@ -145,18 +153,20 @@ class EventsBackgroundUpdatesStore(BackgroundUpdateStore):
                 "rows_inserted": rows_inserted + len(rows),
             }
 
-            self._background_update_progress_txn(
+            self.db.updates._background_update_progress_txn(
                 txn, self.EVENT_FIELDS_SENDER_URL_UPDATE_NAME, progress
             )
 
             return len(rows)
 
-        result = yield self.runInteraction(
+        result = yield self.db.runInteraction(
             self.EVENT_FIELDS_SENDER_URL_UPDATE_NAME, reindex_txn
         )
 
         if not result:
-            yield self._end_background_update(self.EVENT_FIELDS_SENDER_URL_UPDATE_NAME)
+            yield self.db.updates._end_background_update(
+                self.EVENT_FIELDS_SENDER_URL_UPDATE_NAME
+            )
 
         return result
 
@@ -189,7 +199,7 @@ class EventsBackgroundUpdatesStore(BackgroundUpdateStore):
 
             chunks = [event_ids[i : i + 100] for i in range(0, len(event_ids), 100)]
             for chunk in chunks:
-                ev_rows = self.simple_select_many_txn(
+                ev_rows = self.db.simple_select_many_txn(
                     txn,
                     table="event_json",
                     column="event_id",
@@ -222,18 +232,20 @@ class EventsBackgroundUpdatesStore(BackgroundUpdateStore):
                 "rows_inserted": rows_inserted + len(rows_to_update),
             }
 
-            self._background_update_progress_txn(
+            self.db.updates._background_update_progress_txn(
                 txn, self.EVENT_ORIGIN_SERVER_TS_NAME, progress
             )
 
             return len(rows_to_update)
 
-        result = yield self.runInteraction(
+        result = yield self.db.runInteraction(
             self.EVENT_ORIGIN_SERVER_TS_NAME, reindex_search_txn
         )
 
         if not result:
-            yield self._end_background_update(self.EVENT_ORIGIN_SERVER_TS_NAME)
+            yield self.db.updates._end_background_update(
+                self.EVENT_ORIGIN_SERVER_TS_NAME
+            )
 
         return result
 
@@ -366,7 +378,7 @@ class EventsBackgroundUpdatesStore(BackgroundUpdateStore):
 
             to_delete.intersection_update(original_set)
 
-            deleted = self.simple_delete_many_txn(
+            deleted = self.db.simple_delete_many_txn(
                 txn=txn,
                 table="event_forward_extremities",
                 column="event_id",
@@ -382,7 +394,7 @@ class EventsBackgroundUpdatesStore(BackgroundUpdateStore):
 
             if deleted:
                 # We now need to invalidate the caches of these rooms
-                rows = self.simple_select_many_txn(
+                rows = self.db.simple_select_many_txn(
                     txn,
                     table="events",
                     column="event_id",
@@ -396,7 +408,7 @@ class EventsBackgroundUpdatesStore(BackgroundUpdateStore):
                         self.get_latest_event_ids_in_room.invalidate, (room_id,)
                     )
 
-            self.simple_delete_many_txn(
+            self.db.simple_delete_many_txn(
                 txn=txn,
                 table="_extremities_to_check",
                 column="event_id",
@@ -406,17 +418,19 @@ class EventsBackgroundUpdatesStore(BackgroundUpdateStore):
 
             return len(original_set)
 
-        num_handled = yield self.runInteraction(
+        num_handled = yield self.db.runInteraction(
             "_cleanup_extremities_bg_update", _cleanup_extremities_bg_update_txn
         )
 
         if not num_handled:
-            yield self._end_background_update(self.DELETE_SOFT_FAILED_EXTREMITIES)
+            yield self.db.updates._end_background_update(
+                self.DELETE_SOFT_FAILED_EXTREMITIES
+            )
 
             def _drop_table_txn(txn):
                 txn.execute("DROP TABLE _extremities_to_check")
 
-            yield self.runInteraction(
+            yield self.db.runInteraction(
                 "_cleanup_extremities_bg_update_drop_table", _drop_table_txn
             )
 
@@ -464,18 +478,18 @@ class EventsBackgroundUpdatesStore(BackgroundUpdateStore):
 
             txn.execute(sql, (self._clock.time_msec(), last_event_id, upper_event_id))
 
-            self._background_update_progress_txn(
+            self.db.updates._background_update_progress_txn(
                 txn, "redactions_received_ts", {"last_event_id": upper_event_id}
             )
 
             return len(rows)
 
-        count = yield self.runInteraction(
+        count = yield self.db.runInteraction(
             "_redactions_received_ts", _redactions_received_ts_txn
         )
 
         if not count:
-            yield self._end_background_update("redactions_received_ts")
+            yield self.db.updates._end_background_update("redactions_received_ts")
 
         return count
 
@@ -501,11 +515,11 @@ class EventsBackgroundUpdatesStore(BackgroundUpdateStore):
 
             txn.execute("DROP INDEX redactions_censored_redacts")
 
-        yield self.runInteraction(
+        yield self.db.runInteraction(
             "_event_fix_redactions_bytes", _event_fix_redactions_bytes_txn
         )
 
-        yield self._end_background_update("event_fix_redactions_bytes")
+        yield self.db.updates._end_background_update("event_fix_redactions_bytes")
 
         return 1
 
@@ -533,7 +547,7 @@ class EventsBackgroundUpdatesStore(BackgroundUpdateStore):
                 try:
                     event_json = json.loads(event_json_raw)
 
-                    self.simple_insert_many_txn(
+                    self.db.simple_insert_many_txn(
                         txn=txn,
                         table="event_labels",
                         values=[
@@ -559,17 +573,17 @@ class EventsBackgroundUpdatesStore(BackgroundUpdateStore):
                 nbrows += 1
                 last_row_event_id = event_id
 
-            self._background_update_progress_txn(
+            self.db.updates._background_update_progress_txn(
                 txn, "event_store_labels", {"last_event_id": last_row_event_id}
             )
 
             return nbrows
 
-        num_rows = yield self.runInteraction(
+        num_rows = yield self.db.runInteraction(
             desc="event_store_labels", func=_event_store_labels_txn
         )
 
         if not num_rows:
-            yield self._end_background_update("event_store_labels")
+            yield self.db.updates._end_background_update("event_store_labels")
 
         return num_rows
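
As the events_bg_updates.py hunks show, background-update registration no longer comes from subclassing BackgroundUpdateStore; it is reached through a registry at `self.db.updates`. The sketch below mirrors only the three call shapes used above (`register_background_update_handler`, `register_background_index_update`, `_end_background_update`); `StubBackgroundUpdater` is hypothetical and runs no SQL.

class StubBackgroundUpdater:
    def __init__(self):
        self.handlers = {}
        self.finished = set()

    def register_background_update_handler(self, update_name, update_handler):
        # Named updates map to handlers that the real updater drains in
        # batches, passing (progress, batch_size) each round.
        self.handlers[update_name] = update_handler

    def register_background_index_update(
        self, update_name, index_name, table, columns,
        where_clause=None, psql_only=False,
    ):
        # The real method synthesises a CREATE INDEX handler; we just record
        # the request to show the parameters involved.
        self.handlers[update_name] = (index_name, table, tuple(columns), where_clause)

    def _end_background_update(self, update_name):
        # Handlers call this once they report no remaining rows.
        self.finished.add(update_name)


updates = StubBackgroundUpdater()
updates.register_background_index_update(
    "redactions_have_censored_ts_idx",
    index_name="redactions_have_censored_ts",
    table="redactions",
    columns=["received_ts"],
    where_clause="NOT have_censored",
)
print(updates.handlers)
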
diff --git a/synapse/storage/data_stores/main/events_worker.py b/synapse/storage/data_stores/main/events_worker.py
index 6a08a746b6..9ee117ce0f 100644
--- a/synapse/storage/data_stores/main/events_worker.py
+++ b/synapse/storage/data_stores/main/events_worker.py
@@ -33,6 +33,7 @@ from synapse.events.utils import prune_event
 from synapse.logging.context import LoggingContext, PreserveLoggingContext
 from synapse.metrics.background_process_metrics import run_as_background_process
 from synapse.storage._base import SQLBaseStore, make_in_list_sql_clause
+from synapse.storage.database import Database
 from synapse.types import get_domain_from_id
 from synapse.util import batch_iter
 from synapse.util.caches.descriptors import Cache
@@ -55,8 +56,8 @@ _EventCacheEntry = namedtuple("_EventCacheEntry", ("event", "redacted_event"))
 
 
 class EventsWorkerStore(SQLBaseStore):
-    def __init__(self, db_conn, hs):
-        super(EventsWorkerStore, self).__init__(db_conn, hs)
+    def __init__(self, database: Database, db_conn, hs):
+        super(EventsWorkerStore, self).__init__(database, db_conn, hs)
 
         self._get_event_cache = Cache(
             "*getEvent*", keylen=3, max_entries=hs.config.event_cache_size
@@ -78,7 +79,7 @@ class EventsWorkerStore(SQLBaseStore):
             Deferred[int|None]: Timestamp in milliseconds, or None for events
             that were persisted before received_ts was implemented.
         """
-        return self.simple_select_one_onecol(
+        return self.db.simple_select_one_onecol(
             table="events",
             keyvalues={"event_id": event_id},
             retcol="received_ts",
@@ -117,7 +118,7 @@ class EventsWorkerStore(SQLBaseStore):
 
             return ts
 
-        return self.runInteraction(
+        return self.db.runInteraction(
             "get_approximate_received_ts", _get_approximate_received_ts_txn
         )
 
@@ -452,7 +453,7 @@ class EventsWorkerStore(SQLBaseStore):
                     event_id for events, _ in event_list for event_id in events
                 )
 
-                row_dict = self.new_transaction(
+                row_dict = self.db.new_transaction(
                     conn, "do_fetch", [], [], self._fetch_event_rows, events_to_fetch
                 )
 
@@ -584,7 +585,7 @@ class EventsWorkerStore(SQLBaseStore):
 
         if should_start:
             run_as_background_process(
-                "fetch_events", self.runWithConnection, self._do_fetch
+                "fetch_events", self.db.runWithConnection, self._do_fetch
             )
 
         logger.debug("Loading %d events: %s", len(events), events)
@@ -745,7 +746,7 @@ class EventsWorkerStore(SQLBaseStore):
         """Given a list of event ids, check if we have already processed and
         stored them as non outliers.
         """
-        rows = yield self.simple_select_many_batch(
+        rows = yield self.db.simple_select_many_batch(
             table="events",
             retcols=("event_id",),
             column="event_id",
@@ -780,7 +781,9 @@ class EventsWorkerStore(SQLBaseStore):
         # break the input up into chunks of 100
         input_iterator = iter(event_ids)
         for chunk in iter(lambda: list(itertools.islice(input_iterator, 100)), []):
-            yield self.runInteraction("have_seen_events", have_seen_events_txn, chunk)
+            yield self.db.runInteraction(
+                "have_seen_events", have_seen_events_txn, chunk
+            )
         return results
 
     def _get_total_state_event_counts_txn(self, txn, room_id):
@@ -807,7 +810,7 @@ class EventsWorkerStore(SQLBaseStore):
         Returns:
             Deferred[int]
         """
-        return self.runInteraction(
+        return self.db.runInteraction(
             "get_total_state_event_counts",
             self._get_total_state_event_counts_txn,
             room_id,
@@ -832,7 +835,7 @@ class EventsWorkerStore(SQLBaseStore):
         Returns:
             Deferred[int]
         """
-        return self.runInteraction(
+        return self.db.runInteraction(
             "get_current_state_event_counts",
             self._get_current_state_event_counts_txn,
             room_id,
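
The events_worker.py hunks also show the constructor change applied throughout this commit: every store now accepts `(database: Database, db_conn, hs)` and threads all three up the inheritance chain. A minimal sketch, with stub types standing in for Database and the homeserver object:

class StubDatabase:
    pass


class SQLBaseStoreSketch:
    def __init__(self, database: StubDatabase, db_conn, hs):
        self.db = database          # shared handle used by all helpers
        self.hs = hs


class EventsWorkerStoreSketch(SQLBaseStoreSketch):
    def __init__(self, database: StubDatabase, db_conn, hs):
        # `database` passes through every super().__init__ unchanged.
        super().__init__(database, db_conn, hs)
        self._get_event_cache = {}  # stand-in for the real LRU Cache


store = EventsWorkerStoreSketch(StubDatabase(), db_conn=None, hs=object())
print(type(store.db).__name__)
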
diff --git a/synapse/storage/data_stores/main/filtering.py b/synapse/storage/data_stores/main/filtering.py
index 17ef7b9354..342d6622a4 100644
--- a/synapse/storage/data_stores/main/filtering.py
+++ b/synapse/storage/data_stores/main/filtering.py
@@ -30,7 +30,7 @@ class FilteringStore(SQLBaseStore):
         except ValueError:
             raise SynapseError(400, "Invalid filter ID", Codes.INVALID_PARAM)
 
-        def_json = yield self.simple_select_one_onecol(
+        def_json = yield self.db.simple_select_one_onecol(
             table="user_filters",
             keyvalues={"user_id": user_localpart, "filter_id": filter_id},
             retcol="filter_json",
@@ -71,4 +71,4 @@ class FilteringStore(SQLBaseStore):
 
             return filter_id
 
-        return self.runInteraction("add_user_filter", _do_txn)
+        return self.db.runInteraction("add_user_filter", _do_txn)
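
FilteringStore illustrates how `@defer.inlineCallbacks` code consumes the relocated helpers: `yield self.db.simple_select_one_onecol(...)` resolves the returned Deferred inline. The sketch below is hypothetical; `StubDatabase` fakes the helper with `defer.succeed`, and its `allow_none` parameter is an assumption since the hunk is cut off after `retcol`.

from twisted.internet import defer


class StubDatabase:
    def simple_select_one_onecol(self, table, keyvalues, retcol, allow_none=False):
        # The real helper runs "SELECT <retcol> FROM <table> WHERE ..." and
        # returns a Deferred firing with the single column value.
        return defer.succeed('{"room": {}}')


class FilteringStoreSketch:
    def __init__(self, database):
        self.db = database

    @defer.inlineCallbacks
    def get_user_filter(self, user_localpart, filter_id):
        def_json = yield self.db.simple_select_one_onecol(
            table="user_filters",
            keyvalues={"user_id": user_localpart, "filter_id": filter_id},
            retcol="filter_json",
        )
        return def_json


FilteringStoreSketch(StubDatabase()).get_user_filter("alice", 1).addCallback(print)
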
diff --git a/synapse/storage/data_stores/main/group_server.py b/synapse/storage/data_stores/main/group_server.py
index 9e1d12bcb7..6acd45e9f3 100644
--- a/synapse/storage/data_stores/main/group_server.py
+++ b/synapse/storage/data_stores/main/group_server.py
@@ -35,7 +35,7 @@ class GroupServerStore(SQLBaseStore):
          * "invite"
          * "open"
         """
-        return self.simple_update_one(
+        return self.db.simple_update_one(
             table="groups",
             keyvalues={"group_id": group_id},
             updatevalues={"join_policy": join_policy},
@@ -43,7 +43,7 @@ class GroupServerStore(SQLBaseStore):
         )
 
     def get_group(self, group_id):
-        return self.simple_select_one(
+        return self.db.simple_select_one(
             table="groups",
             keyvalues={"group_id": group_id},
             retcols=(
@@ -65,7 +65,7 @@ class GroupServerStore(SQLBaseStore):
         if not include_private:
             keyvalues["is_public"] = True
 
-        return self.simple_select_list(
+        return self.db.simple_select_list(
             table="group_users",
             keyvalues=keyvalues,
             retcols=("user_id", "is_public", "is_admin"),
@@ -75,7 +75,7 @@ class GroupServerStore(SQLBaseStore):
     def get_invited_users_in_group(self, group_id):
         # TODO: Pagination
 
-        return self.simple_select_onecol(
+        return self.db.simple_select_onecol(
             table="group_invites",
             keyvalues={"group_id": group_id},
             retcol="user_id",
@@ -89,7 +89,7 @@ class GroupServerStore(SQLBaseStore):
         if not include_private:
             keyvalues["is_public"] = True
 
-        return self.simple_select_list(
+        return self.db.simple_select_list(
             table="group_rooms",
             keyvalues=keyvalues,
             retcols=("room_id", "is_public"),
@@ -153,10 +153,12 @@ class GroupServerStore(SQLBaseStore):
 
             return rooms, categories
 
-        return self.runInteraction("get_rooms_for_summary", _get_rooms_for_summary_txn)
+        return self.db.runInteraction(
+            "get_rooms_for_summary", _get_rooms_for_summary_txn
+        )
 
     def add_room_to_summary(self, group_id, room_id, category_id, order, is_public):
-        return self.runInteraction(
+        return self.db.runInteraction(
             "add_room_to_summary",
             self._add_room_to_summary_txn,
             group_id,
@@ -180,7 +182,7 @@ class GroupServerStore(SQLBaseStore):
                 an order of 1 will put the room first. Otherwise, the room gets
                 added to the end.
         """
-        room_in_group = self.simple_select_one_onecol_txn(
+        room_in_group = self.db.simple_select_one_onecol_txn(
             txn,
             table="group_rooms",
             keyvalues={"group_id": group_id, "room_id": room_id},
@@ -193,7 +195,7 @@ class GroupServerStore(SQLBaseStore):
         if category_id is None:
             category_id = _DEFAULT_CATEGORY_ID
         else:
-            cat_exists = self.simple_select_one_onecol_txn(
+            cat_exists = self.db.simple_select_one_onecol_txn(
                 txn,
                 table="group_room_categories",
                 keyvalues={"group_id": group_id, "category_id": category_id},
@@ -204,7 +206,7 @@ class GroupServerStore(SQLBaseStore):
                 raise SynapseError(400, "Category doesn't exist")
 
             # TODO: Check category is part of summary already
-            cat_exists = self.simple_select_one_onecol_txn(
+            cat_exists = self.db.simple_select_one_onecol_txn(
                 txn,
                 table="group_summary_room_categories",
                 keyvalues={"group_id": group_id, "category_id": category_id},
@@ -224,7 +226,7 @@ class GroupServerStore(SQLBaseStore):
                     (group_id, category_id, group_id, category_id),
                 )
 
-        existing = self.simple_select_one_txn(
+        existing = self.db.simple_select_one_txn(
             txn,
             table="group_summary_rooms",
             keyvalues={
@@ -257,7 +259,7 @@ class GroupServerStore(SQLBaseStore):
                 to_update["room_order"] = order
             if is_public is not None:
                 to_update["is_public"] = is_public
-            self.simple_update_txn(
+            self.db.simple_update_txn(
                 txn,
                 table="group_summary_rooms",
                 keyvalues={
@@ -271,7 +273,7 @@ class GroupServerStore(SQLBaseStore):
             if is_public is None:
                 is_public = True
 
-            self.simple_insert_txn(
+            self.db.simple_insert_txn(
                 txn,
                 table="group_summary_rooms",
                 values={
@@ -287,7 +289,7 @@ class GroupServerStore(SQLBaseStore):
         if category_id is None:
             category_id = _DEFAULT_CATEGORY_ID
 
-        return self.simple_delete(
+        return self.db.simple_delete(
             table="group_summary_rooms",
             keyvalues={
                 "group_id": group_id,
@@ -299,7 +301,7 @@ class GroupServerStore(SQLBaseStore):
 
     @defer.inlineCallbacks
     def get_group_categories(self, group_id):
-        rows = yield self.simple_select_list(
+        rows = yield self.db.simple_select_list(
             table="group_room_categories",
             keyvalues={"group_id": group_id},
             retcols=("category_id", "is_public", "profile"),
@@ -316,7 +318,7 @@ class GroupServerStore(SQLBaseStore):
 
     @defer.inlineCallbacks
     def get_group_category(self, group_id, category_id):
-        category = yield self.simple_select_one(
+        category = yield self.db.simple_select_one(
             table="group_room_categories",
             keyvalues={"group_id": group_id, "category_id": category_id},
             retcols=("is_public", "profile"),
@@ -343,7 +345,7 @@ class GroupServerStore(SQLBaseStore):
         else:
             update_values["is_public"] = is_public
 
-        return self.simple_upsert(
+        return self.db.simple_upsert(
             table="group_room_categories",
             keyvalues={"group_id": group_id, "category_id": category_id},
             values=update_values,
@@ -352,7 +354,7 @@ class GroupServerStore(SQLBaseStore):
         )
 
     def remove_group_category(self, group_id, category_id):
-        return self.simple_delete(
+        return self.db.simple_delete(
             table="group_room_categories",
             keyvalues={"group_id": group_id, "category_id": category_id},
             desc="remove_group_category",
@@ -360,7 +362,7 @@ class GroupServerStore(SQLBaseStore):
 
     @defer.inlineCallbacks
     def get_group_roles(self, group_id):
-        rows = yield self.simple_select_list(
+        rows = yield self.db.simple_select_list(
             table="group_roles",
             keyvalues={"group_id": group_id},
             retcols=("role_id", "is_public", "profile"),
@@ -377,7 +379,7 @@ class GroupServerStore(SQLBaseStore):
 
     @defer.inlineCallbacks
     def get_group_role(self, group_id, role_id):
-        role = yield self.simple_select_one(
+        role = yield self.db.simple_select_one(
             table="group_roles",
             keyvalues={"group_id": group_id, "role_id": role_id},
             retcols=("is_public", "profile"),
@@ -404,7 +406,7 @@ class GroupServerStore(SQLBaseStore):
         else:
             update_values["is_public"] = is_public
 
-        return self.simple_upsert(
+        return self.db.simple_upsert(
             table="group_roles",
             keyvalues={"group_id": group_id, "role_id": role_id},
             values=update_values,
@@ -413,14 +415,14 @@ class GroupServerStore(SQLBaseStore):
         )
 
     def remove_group_role(self, group_id, role_id):
-        return self.simple_delete(
+        return self.db.simple_delete(
             table="group_roles",
             keyvalues={"group_id": group_id, "role_id": role_id},
             desc="remove_group_role",
         )
 
     def add_user_to_summary(self, group_id, user_id, role_id, order, is_public):
-        return self.runInteraction(
+        return self.db.runInteraction(
             "add_user_to_summary",
             self._add_user_to_summary_txn,
             group_id,
@@ -444,7 +446,7 @@ class GroupServerStore(SQLBaseStore):
                 an order of 1 will put the user first. Otherwise, the user gets
                 added to the end.
         """
-        user_in_group = self.simple_select_one_onecol_txn(
+        user_in_group = self.db.simple_select_one_onecol_txn(
             txn,
             table="group_users",
             keyvalues={"group_id": group_id, "user_id": user_id},
@@ -457,7 +459,7 @@ class GroupServerStore(SQLBaseStore):
         if role_id is None:
             role_id = _DEFAULT_ROLE_ID
         else:
-            role_exists = self.simple_select_one_onecol_txn(
+            role_exists = self.db.simple_select_one_onecol_txn(
                 txn,
                 table="group_roles",
                 keyvalues={"group_id": group_id, "role_id": role_id},
@@ -468,7 +470,7 @@ class GroupServerStore(SQLBaseStore):
                 raise SynapseError(400, "Role doesn't exist")
 
             # TODO: Check role is part of the summary already
-            role_exists = self.simple_select_one_onecol_txn(
+            role_exists = self.db.simple_select_one_onecol_txn(
                 txn,
                 table="group_summary_roles",
                 keyvalues={"group_id": group_id, "role_id": role_id},
@@ -488,7 +490,7 @@ class GroupServerStore(SQLBaseStore):
                     (group_id, role_id, group_id, role_id),
                 )
 
-        existing = self.simple_select_one_txn(
+        existing = self.db.simple_select_one_txn(
             txn,
             table="group_summary_users",
             keyvalues={"group_id": group_id, "user_id": user_id, "role_id": role_id},
@@ -517,7 +519,7 @@ class GroupServerStore(SQLBaseStore):
                 to_update["user_order"] = order
             if is_public is not None:
                 to_update["is_public"] = is_public
-            self.simple_update_txn(
+            self.db.simple_update_txn(
                 txn,
                 table="group_summary_users",
                 keyvalues={
@@ -531,7 +533,7 @@ class GroupServerStore(SQLBaseStore):
             if is_public is None:
                 is_public = True
 
-            self.simple_insert_txn(
+            self.db.simple_insert_txn(
                 txn,
                 table="group_summary_users",
                 values={
@@ -547,7 +549,7 @@ class GroupServerStore(SQLBaseStore):
         if role_id is None:
             role_id = _DEFAULT_ROLE_ID
 
-        return self.simple_delete(
+        return self.db.simple_delete(
             table="group_summary_users",
             keyvalues={"group_id": group_id, "role_id": role_id, "user_id": user_id},
             desc="remove_user_from_summary",
@@ -561,7 +563,7 @@ class GroupServerStore(SQLBaseStore):
             Deferred[list[str]]: A twisted.Deferred containing a list of group ids
                 containing this room
         """
-        return self.simple_select_onecol(
+        return self.db.simple_select_onecol(
             table="group_rooms",
             keyvalues={"room_id": room_id},
             retcol="group_id",
@@ -625,12 +627,12 @@ class GroupServerStore(SQLBaseStore):
 
             return users, roles
 
-        return self.runInteraction(
+        return self.db.runInteraction(
             "get_users_for_summary_by_role", _get_users_for_summary_txn
         )
 
     def is_user_in_group(self, user_id, group_id):
-        return self.simple_select_one_onecol(
+        return self.db.simple_select_one_onecol(
             table="group_users",
             keyvalues={"group_id": group_id, "user_id": user_id},
             retcol="user_id",
@@ -639,7 +641,7 @@ class GroupServerStore(SQLBaseStore):
         ).addCallback(lambda r: bool(r))
 
     def is_user_admin_in_group(self, group_id, user_id):
-        return self.simple_select_one_onecol(
+        return self.db.simple_select_one_onecol(
             table="group_users",
             keyvalues={"group_id": group_id, "user_id": user_id},
             retcol="is_admin",
@@ -650,7 +652,7 @@ class GroupServerStore(SQLBaseStore):
     def add_group_invite(self, group_id, user_id):
         """Record that the group server has invited a user
         """
-        return self.simple_insert(
+        return self.db.simple_insert(
             table="group_invites",
             values={"group_id": group_id, "user_id": user_id},
             desc="add_group_invite",
@@ -659,7 +661,7 @@ class GroupServerStore(SQLBaseStore):
     def is_user_invited_to_local_group(self, group_id, user_id):
         """Has the group server invited a user?
         """
-        return self.simple_select_one_onecol(
+        return self.db.simple_select_one_onecol(
             table="group_invites",
             keyvalues={"group_id": group_id, "user_id": user_id},
             retcol="user_id",
@@ -682,7 +684,7 @@ class GroupServerStore(SQLBaseStore):
         """
 
         def _get_users_membership_in_group_txn(txn):
-            row = self.simple_select_one_txn(
+            row = self.db.simple_select_one_txn(
                 txn,
                 table="group_users",
                 keyvalues={"group_id": group_id, "user_id": user_id},
@@ -697,7 +699,7 @@ class GroupServerStore(SQLBaseStore):
                     "is_privileged": row["is_admin"],
                 }
 
-            row = self.simple_select_one_onecol_txn(
+            row = self.db.simple_select_one_onecol_txn(
                 txn,
                 table="group_invites",
                 keyvalues={"group_id": group_id, "user_id": user_id},
@@ -710,7 +712,7 @@ class GroupServerStore(SQLBaseStore):
 
             return {}
 
-        return self.runInteraction(
+        return self.db.runInteraction(
             "get_users_membership_info_in_group", _get_users_membership_in_group_txn
         )
 
@@ -738,7 +740,7 @@ class GroupServerStore(SQLBaseStore):
         """
 
         def _add_user_to_group_txn(txn):
-            self.simple_insert_txn(
+            self.db.simple_insert_txn(
                 txn,
                 table="group_users",
                 values={
@@ -749,14 +751,14 @@ class GroupServerStore(SQLBaseStore):
                 },
             )
 
-            self.simple_delete_txn(
+            self.db.simple_delete_txn(
                 txn,
                 table="group_invites",
                 keyvalues={"group_id": group_id, "user_id": user_id},
             )
 
             if local_attestation:
-                self.simple_insert_txn(
+                self.db.simple_insert_txn(
                     txn,
                     table="group_attestations_renewals",
                     values={
@@ -766,7 +768,7 @@ class GroupServerStore(SQLBaseStore):
                     },
                 )
             if remote_attestation:
-                self.simple_insert_txn(
+                self.db.simple_insert_txn(
                     txn,
                     table="group_attestations_remote",
                     values={
@@ -777,49 +779,49 @@ class GroupServerStore(SQLBaseStore):
                     },
                 )
 
-        return self.runInteraction("add_user_to_group", _add_user_to_group_txn)
+        return self.db.runInteraction("add_user_to_group", _add_user_to_group_txn)
 
     def remove_user_from_group(self, group_id, user_id):
         def _remove_user_from_group_txn(txn):
-            self.simple_delete_txn(
+            self.db.simple_delete_txn(
                 txn,
                 table="group_users",
                 keyvalues={"group_id": group_id, "user_id": user_id},
             )
-            self.simple_delete_txn(
+            self.db.simple_delete_txn(
                 txn,
                 table="group_invites",
                 keyvalues={"group_id": group_id, "user_id": user_id},
             )
-            self.simple_delete_txn(
+            self.db.simple_delete_txn(
                 txn,
                 table="group_attestations_renewals",
                 keyvalues={"group_id": group_id, "user_id": user_id},
             )
-            self.simple_delete_txn(
+            self.db.simple_delete_txn(
                 txn,
                 table="group_attestations_remote",
                 keyvalues={"group_id": group_id, "user_id": user_id},
             )
-            self.simple_delete_txn(
+            self.db.simple_delete_txn(
                 txn,
                 table="group_summary_users",
                 keyvalues={"group_id": group_id, "user_id": user_id},
             )
 
-        return self.runInteraction(
+        return self.db.runInteraction(
             "remove_user_from_group", _remove_user_from_group_txn
         )
 
     def add_room_to_group(self, group_id, room_id, is_public):
-        return self.simple_insert(
+        return self.db.simple_insert(
             table="group_rooms",
             values={"group_id": group_id, "room_id": room_id, "is_public": is_public},
             desc="add_room_to_group",
         )
 
     def update_room_in_group_visibility(self, group_id, room_id, is_public):
-        return self.simple_update(
+        return self.db.simple_update(
             table="group_rooms",
             keyvalues={"group_id": group_id, "room_id": room_id},
             updatevalues={"is_public": is_public},
@@ -828,26 +830,26 @@ class GroupServerStore(SQLBaseStore):
 
     def remove_room_from_group(self, group_id, room_id):
         def _remove_room_from_group_txn(txn):
-            self.simple_delete_txn(
+            self.db.simple_delete_txn(
                 txn,
                 table="group_rooms",
                 keyvalues={"group_id": group_id, "room_id": room_id},
             )
 
-            self.simple_delete_txn(
+            self.db.simple_delete_txn(
                 txn,
                 table="group_summary_rooms",
                 keyvalues={"group_id": group_id, "room_id": room_id},
             )
 
-        return self.runInteraction(
+        return self.db.runInteraction(
             "remove_room_from_group", _remove_room_from_group_txn
         )
 
     def get_publicised_groups_for_user(self, user_id):
         """Get all groups a user is publicising
         """
-        return self.simple_select_onecol(
+        return self.db.simple_select_onecol(
             table="local_group_membership",
             keyvalues={"user_id": user_id, "membership": "join", "is_publicised": True},
             retcol="group_id",
@@ -857,7 +859,7 @@ class GroupServerStore(SQLBaseStore):
     def update_group_publicity(self, group_id, user_id, publicise):
         """Update whether the user is publicising their membership of the group
         """
-        return self.simple_update_one(
+        return self.db.simple_update_one(
             table="local_group_membership",
             keyvalues={"group_id": group_id, "user_id": user_id},
             updatevalues={"is_publicised": publicise},
@@ -893,12 +895,12 @@ class GroupServerStore(SQLBaseStore):
 
         def _register_user_group_membership_txn(txn, next_id):
             # TODO: Upsert?
-            self.simple_delete_txn(
+            self.db.simple_delete_txn(
                 txn,
                 table="local_group_membership",
                 keyvalues={"group_id": group_id, "user_id": user_id},
             )
-            self.simple_insert_txn(
+            self.db.simple_insert_txn(
                 txn,
                 table="local_group_membership",
                 values={
@@ -911,7 +913,7 @@ class GroupServerStore(SQLBaseStore):
                 },
             )
 
-            self.simple_insert_txn(
+            self.db.simple_insert_txn(
                 txn,
                 table="local_group_updates",
                 values={
@@ -930,7 +932,7 @@ class GroupServerStore(SQLBaseStore):
 
             if membership == "join":
                 if local_attestation:
-                    self.simple_insert_txn(
+                    self.db.simple_insert_txn(
                         txn,
                         table="group_attestations_renewals",
                         values={
@@ -940,7 +942,7 @@ class GroupServerStore(SQLBaseStore):
                         },
                     )
                 if remote_attestation:
-                    self.simple_insert_txn(
+                    self.db.simple_insert_txn(
                         txn,
                         table="group_attestations_remote",
                         values={
@@ -951,12 +953,12 @@ class GroupServerStore(SQLBaseStore):
                         },
                     )
             else:
-                self.simple_delete_txn(
+                self.db.simple_delete_txn(
                     txn,
                     table="group_attestations_renewals",
                     keyvalues={"group_id": group_id, "user_id": user_id},
                 )
-                self.simple_delete_txn(
+                self.db.simple_delete_txn(
                     txn,
                     table="group_attestations_remote",
                     keyvalues={"group_id": group_id, "user_id": user_id},
@@ -965,7 +967,7 @@ class GroupServerStore(SQLBaseStore):
             return next_id
 
         with self._group_updates_id_gen.get_next() as next_id:
-            res = yield self.runInteraction(
+            res = yield self.db.runInteraction(
                 "register_user_group_membership",
                 _register_user_group_membership_txn,
                 next_id,
@@ -976,7 +978,7 @@ class GroupServerStore(SQLBaseStore):
     def create_group(
         self, group_id, user_id, name, avatar_url, short_description, long_description
     ):
-        yield self.simple_insert(
+        yield self.db.simple_insert(
             table="groups",
             values={
                 "group_id": group_id,
@@ -991,7 +993,7 @@ class GroupServerStore(SQLBaseStore):
 
     @defer.inlineCallbacks
     def update_group_profile(self, group_id, profile):
-        yield self.simple_update_one(
+        yield self.db.simple_update_one(
             table="groups",
             keyvalues={"group_id": group_id},
             updatevalues=profile,
@@ -1008,16 +1010,16 @@ class GroupServerStore(SQLBaseStore):
                 WHERE valid_until_ms <= ?
             """
             txn.execute(sql, (valid_until_ms,))
-            return self.cursor_to_dict(txn)
+            return self.db.cursor_to_dict(txn)
 
-        return self.runInteraction(
+        return self.db.runInteraction(
             "get_attestations_need_renewals", _get_attestations_need_renewals_txn
         )
 
     def update_attestation_renewal(self, group_id, user_id, attestation):
         """Update an attestation that we have renewed
         """
-        return self.simple_update_one(
+        return self.db.simple_update_one(
             table="group_attestations_renewals",
             keyvalues={"group_id": group_id, "user_id": user_id},
             updatevalues={"valid_until_ms": attestation["valid_until_ms"]},
@@ -1027,7 +1029,7 @@ class GroupServerStore(SQLBaseStore):
     def update_remote_attestion(self, group_id, user_id, attestation):
         """Update an attestation that a remote has renewed
         """
-        return self.simple_update_one(
+        return self.db.simple_update_one(
             table="group_attestations_remote",
             keyvalues={"group_id": group_id, "user_id": user_id},
             updatevalues={
@@ -1046,7 +1048,7 @@ class GroupServerStore(SQLBaseStore):
             group_id (str)
             user_id (str)
         """
-        return self.simple_delete(
+        return self.db.simple_delete(
             table="group_attestations_renewals",
             keyvalues={"group_id": group_id, "user_id": user_id},
             desc="remove_attestation_renewal",
@@ -1057,7 +1059,7 @@ class GroupServerStore(SQLBaseStore):
         """Get the attestation that proves the remote agrees that the user is
         in the group.
         """
-        row = yield self.simple_select_one(
+        row = yield self.db.simple_select_one(
             table="group_attestations_remote",
             keyvalues={"group_id": group_id, "user_id": user_id},
             retcols=("valid_until_ms", "attestation_json"),
@@ -1072,7 +1074,7 @@ class GroupServerStore(SQLBaseStore):
         return None
 
     def get_joined_groups(self, user_id):
-        return self.simple_select_onecol(
+        return self.db.simple_select_onecol(
             table="local_group_membership",
             keyvalues={"user_id": user_id, "membership": "join"},
             retcol="group_id",
@@ -1099,7 +1101,7 @@ class GroupServerStore(SQLBaseStore):
                 for row in txn
             ]
 
-        return self.runInteraction(
+        return self.db.runInteraction(
             "get_all_groups_for_user", _get_all_groups_for_user_txn
         )
 
@@ -1109,7 +1111,7 @@ class GroupServerStore(SQLBaseStore):
             user_id, from_token
         )
         if not has_changed:
-            return []
+            return defer.succeed([])
 
         def _get_groups_changes_for_user_txn(txn):
             sql = """
@@ -1129,7 +1131,7 @@ class GroupServerStore(SQLBaseStore):
                 for group_id, membership, gtype, content_json in txn
             ]
 
-        return self.runInteraction(
+        return self.db.runInteraction(
             "get_groups_changes_for_user", _get_groups_changes_for_user_txn
         )
 
@@ -1139,7 +1141,7 @@ class GroupServerStore(SQLBaseStore):
             from_token
         )
         if not has_changed:
-            return []
+            return defer.succeed([])
 
         def _get_all_groups_changes_txn(txn):
             sql = """
@@ -1154,7 +1156,7 @@ class GroupServerStore(SQLBaseStore):
                 for stream_id, group_id, user_id, gtype, content_json in txn
             ]
 
-        return self.runInteraction(
+        return self.db.runInteraction(
             "get_all_groups_changes", _get_all_groups_changes_txn
         )
 
@@ -1188,8 +1190,8 @@ class GroupServerStore(SQLBaseStore):
             ]
 
             for table in tables:
-                self.simple_delete_txn(
+                self.db.simple_delete_txn(
                     txn, table=table, keyvalues={"group_id": group_id}
                 )
 
-        return self.runInteraction("delete_group", _delete_group_txn)
+        return self.db.runInteraction("delete_group", _delete_group_txn)
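
Two hunks in group_server.py also replace `return []` with `return defer.succeed([])`. These methods are not `@defer.inlineCallbacks` generators, so the fall-through branch returns whatever `db.runInteraction` returns, which is a Deferred; wrapping the early-exit value keeps the return type uniform for callers. A small sketch of that invariant (the list payload is a placeholder):

from twisted.internet import defer


def get_groups_changes_sketch(has_changed):
    if not has_changed:
        return defer.succeed([])  # already-fired Deferred, same shape as below
    # Stand-in for: return self.db.runInteraction("get_groups_changes_for_user", ...)
    return defer.succeed([("group_id", "join", "membership", "{}")])


# Either branch supports the same callback-style consumption:
get_groups_changes_sketch(False).addCallback(print)
get_groups_changes_sketch(True).addCallback(print)
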
diff --git a/synapse/storage/data_stores/main/keys.py b/synapse/storage/data_stores/main/keys.py
index c7150432b3..6b12f5a75f 100644
--- a/synapse/storage/data_stores/main/keys.py
+++ b/synapse/storage/data_stores/main/keys.py
@@ -92,7 +92,7 @@ class KeyStore(SQLBaseStore):
                 _get_keys(txn, batch)
             return keys
 
-        return self.runInteraction("get_server_verify_keys", _txn)
+        return self.db.runInteraction("get_server_verify_keys", _txn)
 
     def store_server_verify_keys(self, from_server, ts_added_ms, verify_keys):
         """Stores NACL verification keys for remote servers.
@@ -127,9 +127,9 @@ class KeyStore(SQLBaseStore):
                 f((i,))
             return res
 
-        return self.runInteraction(
+        return self.db.runInteraction(
             "store_server_verify_keys",
-            self.simple_upsert_many_txn,
+            self.db.simple_upsert_many_txn,
             table="server_signature_keys",
             key_names=("server_name", "key_id"),
             key_values=key_values,
@@ -157,7 +157,7 @@ class KeyStore(SQLBaseStore):
             ts_valid_until_ms (int): The time when this json stops being valid.
             key_json (bytes): The encoded JSON.
         """
-        return self.simple_upsert(
+        return self.db.simple_upsert(
             table="server_keys_json",
             keyvalues={
                 "server_name": server_name,
@@ -196,7 +196,7 @@ class KeyStore(SQLBaseStore):
                     keyvalues["key_id"] = key_id
                 if from_server is not None:
                     keyvalues["from_server"] = from_server
-                rows = self.simple_select_list_txn(
+                rows = self.db.simple_select_list_txn(
                     txn,
                     "server_keys_json",
                     keyvalues=keyvalues,
@@ -211,4 +211,4 @@ class KeyStore(SQLBaseStore):
                 results[(server_name, key_id, from_server)] = rows
             return results
 
-        return self.runInteraction("get_server_keys_json", _get_server_keys_json_txn)
+        return self.db.runInteraction("get_server_keys_json", _get_server_keys_json_txn)
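
`store_server_verify_keys` above passes a `*_txn` helper straight to `runInteraction` together with keyword arguments. A sketch of that shape follows; `StubDatabase` is hypothetical, and the `value_names`/`value_values` keywords are an assumption, since the hunk is cut off after `key_values`.

class StubDatabase:
    def runInteraction(self, desc, func, *args, **kwargs):
        txn = object()  # placeholder transaction handle
        return func(txn, *args, **kwargs)

    def simple_upsert_many_txn(
        self, txn, table, key_names, key_values, value_names, value_values
    ):
        # The real helper issues one upsert per row (or a native batch
        # upsert); here we pair keys with values to show the data shape.
        return {tuple(k): tuple(v) for k, v in zip(key_values, value_values)}


db = StubDatabase()
print(
    db.runInteraction(
        "store_server_verify_keys",
        db.simple_upsert_many_txn,
        table="server_signature_keys",
        key_names=("server_name", "key_id"),
        key_values=[("example.org", "ed25519:a")],
        value_names=("from_server", "ts_added_ms"),
        value_values=[("example.org", 0)],
    )
)
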
diff --git a/synapse/storage/data_stores/main/media_repository.py b/synapse/storage/data_stores/main/media_repository.py
index 0cb9446f96..80ca36dedf 100644
--- a/synapse/storage/data_stores/main/media_repository.py
+++ b/synapse/storage/data_stores/main/media_repository.py
@@ -12,14 +12,17 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
-from synapse.storage.background_updates import BackgroundUpdateStore
+from synapse.storage._base import SQLBaseStore
+from synapse.storage.database import Database
 
 
-class MediaRepositoryBackgroundUpdateStore(BackgroundUpdateStore):
-    def __init__(self, db_conn, hs):
-        super(MediaRepositoryBackgroundUpdateStore, self).__init__(db_conn, hs)
+class MediaRepositoryBackgroundUpdateStore(SQLBaseStore):
+    def __init__(self, database: Database, db_conn, hs):
+        super(MediaRepositoryBackgroundUpdateStore, self).__init__(
+            database, db_conn, hs
+        )
 
-        self.register_background_index_update(
+        self.db.updates.register_background_index_update(
             update_name="local_media_repository_url_idx",
             index_name="local_media_repository_url_idx",
             table="local_media_repository",
@@ -31,15 +34,15 @@ class MediaRepositoryBackgroundUpdateStore(BackgroundUpdateStore):
 class MediaRepositoryStore(MediaRepositoryBackgroundUpdateStore):
     """Persistence for attachments and avatars"""
 
-    def __init__(self, db_conn, hs):
-        super(MediaRepositoryStore, self).__init__(db_conn, hs)
+    def __init__(self, database: Database, db_conn, hs):
+        super(MediaRepositoryStore, self).__init__(database, db_conn, hs)
 
     def get_local_media(self, media_id):
         """Get the metadata for a local piece of media
         Returns:
             None if the media_id doesn't exist.
         """
-        return self.simple_select_one(
+        return self.db.simple_select_one(
             "local_media_repository",
             {"media_id": media_id},
             (
@@ -64,7 +67,7 @@ class MediaRepositoryStore(MediaRepositoryBackgroundUpdateStore):
         user_id,
         url_cache=None,
     ):
-        return self.simple_insert(
+        return self.db.simple_insert(
             "local_media_repository",
             {
                 "media_id": media_id,
@@ -124,12 +127,12 @@ class MediaRepositoryStore(MediaRepositoryBackgroundUpdateStore):
                 )
             )
 
-        return self.runInteraction("get_url_cache", get_url_cache_txn)
+        return self.db.runInteraction("get_url_cache", get_url_cache_txn)
 
     def store_url_cache(
         self, url, response_code, etag, expires_ts, og, media_id, download_ts
     ):
-        return self.simple_insert(
+        return self.db.simple_insert(
             "local_media_repository_url_cache",
             {
                 "url": url,
@@ -144,7 +147,7 @@ class MediaRepositoryStore(MediaRepositoryBackgroundUpdateStore):
         )
 
     def get_local_media_thumbnails(self, media_id):
-        return self.simple_select_list(
+        return self.db.simple_select_list(
             "local_media_repository_thumbnails",
             {"media_id": media_id},
             (
@@ -166,7 +169,7 @@ class MediaRepositoryStore(MediaRepositoryBackgroundUpdateStore):
         thumbnail_method,
         thumbnail_length,
     ):
-        return self.simple_insert(
+        return self.db.simple_insert(
             "local_media_repository_thumbnails",
             {
                 "media_id": media_id,
@@ -180,7 +183,7 @@ class MediaRepositoryStore(MediaRepositoryBackgroundUpdateStore):
         )
 
     def get_cached_remote_media(self, origin, media_id):
-        return self.simple_select_one(
+        return self.db.simple_select_one(
             "remote_media_cache",
             {"media_origin": origin, "media_id": media_id},
             (
@@ -205,7 +208,7 @@ class MediaRepositoryStore(MediaRepositoryBackgroundUpdateStore):
         upload_name,
         filesystem_id,
     ):
-        return self.simple_insert(
+        return self.db.simple_insert(
             "remote_media_cache",
             {
                 "media_origin": origin,
@@ -250,10 +253,12 @@ class MediaRepositoryStore(MediaRepositoryBackgroundUpdateStore):
 
             txn.executemany(sql, ((time_ms, media_id) for media_id in local_media))
 
-        return self.runInteraction("update_cached_last_access_time", update_cache_txn)
+        return self.db.runInteraction(
+            "update_cached_last_access_time", update_cache_txn
+        )
 
     def get_remote_media_thumbnails(self, origin, media_id):
-        return self.simple_select_list(
+        return self.db.simple_select_list(
             "remote_media_cache_thumbnails",
             {"media_origin": origin, "media_id": media_id},
             (
@@ -278,7 +283,7 @@ class MediaRepositoryStore(MediaRepositoryBackgroundUpdateStore):
         thumbnail_method,
         thumbnail_length,
     ):
-        return self.simple_insert(
+        return self.db.simple_insert(
             "remote_media_cache_thumbnails",
             {
                 "media_origin": origin,
@@ -300,24 +305,24 @@ class MediaRepositoryStore(MediaRepositoryBackgroundUpdateStore):
             " WHERE last_access_ts < ?"
         )
 
-        return self.execute(
-            "get_remote_media_before", self.cursor_to_dict, sql, before_ts
+        return self.db.execute(
+            "get_remote_media_before", self.db.cursor_to_dict, sql, before_ts
         )
 
     def delete_remote_media(self, media_origin, media_id):
         def delete_remote_media_txn(txn):
-            self.simple_delete_txn(
+            self.db.simple_delete_txn(
                 txn,
                 "remote_media_cache",
                 keyvalues={"media_origin": media_origin, "media_id": media_id},
             )
-            self.simple_delete_txn(
+            self.db.simple_delete_txn(
                 txn,
                 "remote_media_cache_thumbnails",
                 keyvalues={"media_origin": media_origin, "media_id": media_id},
             )
 
-        return self.runInteraction("delete_remote_media", delete_remote_media_txn)
+        return self.db.runInteraction("delete_remote_media", delete_remote_media_txn)
 
     def get_expired_url_cache(self, now_ts):
         sql = (
@@ -331,7 +336,9 @@ class MediaRepositoryStore(MediaRepositoryBackgroundUpdateStore):
             txn.execute(sql, (now_ts,))
             return [row[0] for row in txn]
 
-        return self.runInteraction("get_expired_url_cache", _get_expired_url_cache_txn)
+        return self.db.runInteraction(
+            "get_expired_url_cache", _get_expired_url_cache_txn
+        )
 
     def delete_url_cache(self, media_ids):
         if len(media_ids) == 0:
@@ -342,7 +349,7 @@ class MediaRepositoryStore(MediaRepositoryBackgroundUpdateStore):
         def _delete_url_cache_txn(txn):
             txn.executemany(sql, [(media_id,) for media_id in media_ids])
 
-        return self.runInteraction("delete_url_cache", _delete_url_cache_txn)
+        return self.db.runInteraction("delete_url_cache", _delete_url_cache_txn)
 
     def get_url_cache_media_before(self, before_ts):
         sql = (
@@ -356,7 +363,7 @@ class MediaRepositoryStore(MediaRepositoryBackgroundUpdateStore):
             txn.execute(sql, (before_ts,))
             return [row[0] for row in txn]
 
-        return self.runInteraction(
+        return self.db.runInteraction(
             "get_url_cache_media_before", _get_url_cache_media_before_txn
         )
 
@@ -373,6 +380,6 @@ class MediaRepositoryStore(MediaRepositoryBackgroundUpdateStore):
 
             txn.executemany(sql, [(media_id,) for media_id in media_ids])
 
-        return self.runInteraction(
+        return self.db.runInteraction(
             "delete_url_cache_media", _delete_url_cache_media_txn
         )
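
`delete_remote_media` above shows the composition pattern this API encourages: several `db.simple_delete_txn` calls share one transaction, so the cache row and its thumbnails disappear together or not at all. A hedged sketch with a synchronous stub:

class StubDatabase:
    def __init__(self):
        self.deleted = []

    def runInteraction(self, desc, func, *args):
        txn = object()
        return func(txn, *args)

    def simple_delete_txn(self, txn, table, keyvalues):
        # The real helper builds "DELETE FROM <table> WHERE ..."; we record it.
        self.deleted.append((table, dict(keyvalues)))


def delete_remote_media_sketch(db, media_origin, media_id):
    def delete_remote_media_txn(txn):
        keyvalues = {"media_origin": media_origin, "media_id": media_id}
        db.simple_delete_txn(txn, "remote_media_cache", keyvalues=keyvalues)
        db.simple_delete_txn(txn, "remote_media_cache_thumbnails", keyvalues=keyvalues)

    return db.runInteraction("delete_remote_media", delete_remote_media_txn)


db = StubDatabase()
delete_remote_media_sketch(db, "example.org", "abc123")
print(db.deleted)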
diff --git a/synapse/storage/data_stores/main/monthly_active_users.py b/synapse/storage/data_stores/main/monthly_active_users.py
index b8fc28f97b..27158534cb 100644
--- a/synapse/storage/data_stores/main/monthly_active_users.py
+++ b/synapse/storage/data_stores/main/monthly_active_users.py
@@ -17,6 +17,7 @@ import logging
 from twisted.internet import defer
 
 from synapse.storage._base import SQLBaseStore
+from synapse.storage.database import Database
 from synapse.util.caches.descriptors import cached
 
 logger = logging.getLogger(__name__)
@@ -27,13 +28,13 @@ LAST_SEEN_GRANULARITY = 60 * 60 * 1000
 
 
 class MonthlyActiveUsersStore(SQLBaseStore):
-    def __init__(self, dbconn, hs):
-        super(MonthlyActiveUsersStore, self).__init__(None, hs)
+    def __init__(self, database: Database, db_conn, hs):
+        super(MonthlyActiveUsersStore, self).__init__(database, db_conn, hs)
         self._clock = hs.get_clock()
         self.hs = hs
         # Do not add more reserved users than the total allowable number
-        self.new_transaction(
-            dbconn,
+        self.db.new_transaction(
+            db_conn,
             "initialise_mau_threepids",
             [],
             [],
@@ -146,7 +147,7 @@ class MonthlyActiveUsersStore(SQLBaseStore):
                     txn.execute(sql, query_args)
 
         reserved_users = yield self.get_registered_reserved_users()
-        yield self.runInteraction(
+        yield self.db.runInteraction(
             "reap_monthly_active_users", _reap_users, reserved_users
         )
         # It seems poor to invalidate the whole cache, Postgres supports
@@ -174,7 +175,7 @@ class MonthlyActiveUsersStore(SQLBaseStore):
             (count,) = txn.fetchone()
             return count
 
-        return self.runInteraction("count_users", _count_users)
+        return self.db.runInteraction("count_users", _count_users)
 
     @defer.inlineCallbacks
     def get_registered_reserved_users(self):
@@ -217,7 +218,7 @@ class MonthlyActiveUsersStore(SQLBaseStore):
         if is_support:
             return
 
-        yield self.runInteraction(
+        yield self.db.runInteraction(
             "upsert_monthly_active_user", self.upsert_monthly_active_user_txn, user_id
         )
 
@@ -261,7 +262,7 @@ class MonthlyActiveUsersStore(SQLBaseStore):
         # never be a big table and alternative approaches (batching multiple
         # upserts into a single txn) introduced a lot of extra complexity.
         # See https://github.com/matrix-org/synapse/issues/3854 for more
-        is_insert = self.simple_upsert_txn(
+        is_insert = self.db.simple_upsert_txn(
             txn,
             table="monthly_active_users",
             keyvalues={"user_id": user_id},
@@ -281,7 +282,7 @@ class MonthlyActiveUsersStore(SQLBaseStore):
 
         """
 
-        return self.simple_select_one_onecol(
+        return self.db.simple_select_one_onecol(
             table="monthly_active_users",
             keyvalues={"user_id": user_id},
             retcol="timestamp",
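
The MonthlyActiveUsersStore hunk is notable in that the old constructor passed `None` up to its superclass; it now forwards the real `db_conn` and runs its one-off initialisation through `db.new_transaction` on that same startup connection. A sketch of that call shape follows; `StubDatabase`, the empty callback lists, and the handler argument are assumed simplifications, as the hunk is cut off after the two lists.

class StubDatabase:
    def new_transaction(self, conn, desc, after_callbacks, exception_callbacks, func, *args):
        # The real method wraps `conn` in a logging transaction and runs
        # `func` synchronously during startup; we call it with a placeholder.
        txn = object()
        return func(txn, *args)


class MonthlyActiveUsersStoreSketch:
    def __init__(self, database, db_conn, hs):
        self.db = database
        # One-off work done on the startup connection, mirroring
        # "initialise_mau_threepids" in the diff.
        self.db.new_transaction(
            db_conn, "initialise_mau_threepids", [], [],
            self._initialise_reserved_users_txn,
        )

    def _initialise_reserved_users_txn(self, txn):
        print("reserved users initialised")


MonthlyActiveUsersStoreSketch(StubDatabase(), db_conn=None, hs=object())
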
diff --git a/synapse/storage/data_stores/main/openid.py b/synapse/storage/data_stores/main/openid.py
index 650e49750e..cc21437e92 100644
--- a/synapse/storage/data_stores/main/openid.py
+++ b/synapse/storage/data_stores/main/openid.py
@@ -3,7 +3,7 @@ from synapse.storage._base import SQLBaseStore
 
 class OpenIdStore(SQLBaseStore):
     def insert_open_id_token(self, token, ts_valid_until_ms, user_id):
-        return self.simple_insert(
+        return self.db.simple_insert(
             table="open_id_tokens",
             values={
                 "token": token,
@@ -28,4 +28,6 @@ class OpenIdStore(SQLBaseStore):
             else:
                 return rows[0][0]
 
-        return self.runInteraction("get_user_id_for_token", get_user_id_for_token_txn)
+        return self.db.runInteraction(
+            "get_user_id_for_token", get_user_id_for_token_txn
+        )
diff --git a/synapse/storage/data_stores/main/presence.py b/synapse/storage/data_stores/main/presence.py
index a5e121efd1..a2c83e0867 100644
--- a/synapse/storage/data_stores/main/presence.py
+++ b/synapse/storage/data_stores/main/presence.py
@@ -29,7 +29,7 @@ class PresenceStore(SQLBaseStore):
         )
 
         with stream_ordering_manager as stream_orderings:
-            yield self.runInteraction(
+            yield self.db.runInteraction(
                 "update_presence",
                 self._update_presence_txn,
                 stream_orderings,
@@ -46,7 +46,7 @@ class PresenceStore(SQLBaseStore):
             txn.call_after(self._get_presence_for_user.invalidate, (state.user_id,))
 
         # Actually insert new rows
-        self.simple_insert_many_txn(
+        self.db.simple_insert_many_txn(
             txn,
             table="presence_stream",
             values=[
@@ -88,7 +88,7 @@ class PresenceStore(SQLBaseStore):
             txn.execute(sql, (last_id, current_id))
             return txn.fetchall()
 
-        return self.runInteraction(
+        return self.db.runInteraction(
             "get_all_presence_updates", get_all_presence_updates_txn
         )
 
@@ -103,7 +103,7 @@ class PresenceStore(SQLBaseStore):
         inlineCallbacks=True,
     )
     def get_presence_for_users(self, user_ids):
-        rows = yield self.simple_select_many_batch(
+        rows = yield self.db.simple_select_many_batch(
             table="presence_stream",
             column="user_id",
             iterable=user_ids,
@@ -129,7 +129,7 @@ class PresenceStore(SQLBaseStore):
         return self._presence_id_gen.get_current_token()
 
     def allow_presence_visible(self, observed_localpart, observer_userid):
-        return self.simple_insert(
+        return self.db.simple_insert(
             table="presence_allow_inbound",
             values={
                 "observed_user_id": observed_localpart,
@@ -140,7 +140,7 @@ class PresenceStore(SQLBaseStore):
         )
 
     def disallow_presence_visible(self, observed_localpart, observer_userid):
-        return self.simple_delete_one(
+        return self.db.simple_delete_one(
             table="presence_allow_inbound",
             keyvalues={
                 "observed_user_id": observed_localpart,
diff --git a/synapse/storage/data_stores/main/profile.py b/synapse/storage/data_stores/main/profile.py
index c8b5b60301..2b52cf9c1a 100644
--- a/synapse/storage/data_stores/main/profile.py
+++ b/synapse/storage/data_stores/main/profile.py
@@ -24,7 +24,7 @@ class ProfileWorkerStore(SQLBaseStore):
     @defer.inlineCallbacks
     def get_profileinfo(self, user_localpart):
         try:
-            profile = yield self.simple_select_one(
+            profile = yield self.db.simple_select_one(
                 table="profiles",
                 keyvalues={"user_id": user_localpart},
                 retcols=("displayname", "avatar_url"),
@@ -42,7 +42,7 @@ class ProfileWorkerStore(SQLBaseStore):
         )
 
     def get_profile_displayname(self, user_localpart):
-        return self.simple_select_one_onecol(
+        return self.db.simple_select_one_onecol(
             table="profiles",
             keyvalues={"user_id": user_localpart},
             retcol="displayname",
@@ -50,7 +50,7 @@ class ProfileWorkerStore(SQLBaseStore):
         )
 
     def get_profile_avatar_url(self, user_localpart):
-        return self.simple_select_one_onecol(
+        return self.db.simple_select_one_onecol(
             table="profiles",
             keyvalues={"user_id": user_localpart},
             retcol="avatar_url",
@@ -58,7 +58,7 @@ class ProfileWorkerStore(SQLBaseStore):
         )
 
     def get_from_remote_profile_cache(self, user_id):
-        return self.simple_select_one(
+        return self.db.simple_select_one(
             table="remote_profile_cache",
             keyvalues={"user_id": user_id},
             retcols=("displayname", "avatar_url"),
@@ -67,12 +67,12 @@ class ProfileWorkerStore(SQLBaseStore):
         )
 
     def create_profile(self, user_localpart):
-        return self.simple_insert(
+        return self.db.simple_insert(
             table="profiles", values={"user_id": user_localpart}, desc="create_profile"
         )
 
     def set_profile_displayname(self, user_localpart, new_displayname):
-        return self.simple_update_one(
+        return self.db.simple_update_one(
             table="profiles",
             keyvalues={"user_id": user_localpart},
             updatevalues={"displayname": new_displayname},
@@ -80,7 +80,7 @@ class ProfileWorkerStore(SQLBaseStore):
         )
 
     def set_profile_avatar_url(self, user_localpart, new_avatar_url):
-        return self.simple_update_one(
+        return self.db.simple_update_one(
             table="profiles",
             keyvalues={"user_id": user_localpart},
             updatevalues={"avatar_url": new_avatar_url},
@@ -95,7 +95,7 @@ class ProfileStore(ProfileWorkerStore):
         This should only be called when `is_subscribed_remote_profile_for_user`
         would return true for the user.
         """
-        return self.simple_upsert(
+        return self.db.simple_upsert(
             table="remote_profile_cache",
             keyvalues={"user_id": user_id},
             values={
@@ -107,7 +107,7 @@ class ProfileStore(ProfileWorkerStore):
         )
 
     def update_remote_profile_cache(self, user_id, displayname, avatar_url):
-        return self.simple_update(
+        return self.db.simple_update(
             table="remote_profile_cache",
             keyvalues={"user_id": user_id},
             values={
@@ -125,7 +125,7 @@ class ProfileStore(ProfileWorkerStore):
         """
         subscribed = yield self.is_subscribed_remote_profile_for_user(user_id)
         if not subscribed:
-            yield self.simple_delete(
+            yield self.db.simple_delete(
                 table="remote_profile_cache",
                 keyvalues={"user_id": user_id},
                 desc="delete_remote_profile_cache",
@@ -144,9 +144,9 @@ class ProfileStore(ProfileWorkerStore):
 
             txn.execute(sql, (last_checked,))
 
-            return self.cursor_to_dict(txn)
+            return self.db.cursor_to_dict(txn)
 
-        return self.runInteraction(
+        return self.db.runInteraction(
             "get_remote_profile_cache_entries_that_expire",
             _get_remote_profile_cache_entries_that_expire_txn,
         )
@@ -155,7 +155,7 @@ class ProfileStore(ProfileWorkerStore):
     def is_subscribed_remote_profile_for_user(self, user_id):
         """Check whether we are interested in a remote user's profile.
         """
-        res = yield self.simple_select_one_onecol(
+        res = yield self.db.simple_select_one_onecol(
             table="group_users",
             keyvalues={"user_id": user_id},
             retcol="user_id",
@@ -166,7 +166,7 @@ class ProfileStore(ProfileWorkerStore):
         if res:
             return True
 
-        res = yield self.simple_select_one_onecol(
+        res = yield self.db.simple_select_one_onecol(
             table="group_invites",
             keyvalues={"user_id": user_id},
             retcol="user_id",
diff --git a/synapse/storage/data_stores/main/push_rule.py b/synapse/storage/data_stores/main/push_rule.py
index 75bd499bcd..5ba13aa973 100644
--- a/synapse/storage/data_stores/main/push_rule.py
+++ b/synapse/storage/data_stores/main/push_rule.py
@@ -27,6 +27,7 @@ from synapse.storage.data_stores.main.appservice import ApplicationServiceWorker
 from synapse.storage.data_stores.main.pusher import PusherWorkerStore
 from synapse.storage.data_stores.main.receipts import ReceiptsWorkerStore
 from synapse.storage.data_stores.main.roommember import RoomMemberWorkerStore
+from synapse.storage.database import Database
 from synapse.storage.push_rule import InconsistentRuleException, RuleNotFoundException
 from synapse.util.caches.descriptors import cachedInlineCallbacks, cachedList
 from synapse.util.caches.stream_change_cache import StreamChangeCache
@@ -72,10 +73,10 @@ class PushRulesWorkerStore(
     # the abstract methods being implemented.
     __metaclass__ = abc.ABCMeta
 
-    def __init__(self, db_conn, hs):
-        super(PushRulesWorkerStore, self).__init__(db_conn, hs)
+    def __init__(self, database: Database, db_conn, hs):
+        super(PushRulesWorkerStore, self).__init__(database, db_conn, hs)
 
-        push_rules_prefill, push_rules_id = self.get_cache_dict(
+        push_rules_prefill, push_rules_id = self.db.get_cache_dict(
             db_conn,
             "push_rules_stream",
             entity_column="user_id",
@@ -100,7 +101,7 @@ class PushRulesWorkerStore(
 
     @cachedInlineCallbacks(max_entries=5000)
     def get_push_rules_for_user(self, user_id):
-        rows = yield self.simple_select_list(
+        rows = yield self.db.simple_select_list(
             table="push_rules",
             keyvalues={"user_name": user_id},
             retcols=(
@@ -124,7 +125,7 @@ class PushRulesWorkerStore(
 
     @cachedInlineCallbacks(max_entries=5000)
     def get_push_rules_enabled_for_user(self, user_id):
-        results = yield self.simple_select_list(
+        results = yield self.db.simple_select_list(
             table="push_rules_enable",
             keyvalues={"user_name": user_id},
             retcols=("user_name", "rule_id", "enabled"),
@@ -146,7 +147,7 @@ class PushRulesWorkerStore(
                 (count,) = txn.fetchone()
                 return bool(count)
 
-            return self.runInteraction(
+            return self.db.runInteraction(
                 "have_push_rules_changed", have_push_rules_changed_txn
             )
 
@@ -162,7 +163,7 @@ class PushRulesWorkerStore(
 
         results = {user_id: [] for user_id in user_ids}
 
-        rows = yield self.simple_select_many_batch(
+        rows = yield self.db.simple_select_many_batch(
             table="push_rules",
             column="user_name",
             iterable=user_ids,
@@ -320,7 +321,7 @@ class PushRulesWorkerStore(
 
         results = {user_id: {} for user_id in user_ids}
 
-        rows = yield self.simple_select_many_batch(
+        rows = yield self.db.simple_select_many_batch(
             table="push_rules_enable",
             column="user_name",
             iterable=user_ids,
@@ -350,7 +351,7 @@ class PushRuleStore(PushRulesWorkerStore):
         with self._push_rules_stream_id_gen.get_next() as ids:
             stream_id, event_stream_ordering = ids
             if before or after:
-                yield self.runInteraction(
+                yield self.db.runInteraction(
                     "_add_push_rule_relative_txn",
                     self._add_push_rule_relative_txn,
                     stream_id,
@@ -364,7 +365,7 @@ class PushRuleStore(PushRulesWorkerStore):
                     after,
                 )
             else:
-                yield self.runInteraction(
+                yield self.db.runInteraction(
                     "_add_push_rule_highest_priority_txn",
                     self._add_push_rule_highest_priority_txn,
                     stream_id,
@@ -395,7 +396,7 @@ class PushRuleStore(PushRulesWorkerStore):
 
         relative_to_rule = before or after
 
-        res = self.simple_select_one_txn(
+        res = self.db.simple_select_one_txn(
             txn,
             table="push_rules",
             keyvalues={"user_name": user_id, "rule_id": relative_to_rule},
@@ -518,7 +519,7 @@ class PushRuleStore(PushRulesWorkerStore):
             # We didn't update a row with the given rule_id so insert one
             push_rule_id = self._push_rule_id_gen.get_next()
 
-            self.simple_insert_txn(
+            self.db.simple_insert_txn(
                 txn,
                 table="push_rules",
                 values={
@@ -561,7 +562,7 @@ class PushRuleStore(PushRulesWorkerStore):
         """
 
         def delete_push_rule_txn(txn, stream_id, event_stream_ordering):
-            self.simple_delete_one_txn(
+            self.db.simple_delete_one_txn(
                 txn, "push_rules", {"user_name": user_id, "rule_id": rule_id}
             )
 
@@ -571,7 +572,7 @@ class PushRuleStore(PushRulesWorkerStore):
 
         with self._push_rules_stream_id_gen.get_next() as ids:
             stream_id, event_stream_ordering = ids
-            yield self.runInteraction(
+            yield self.db.runInteraction(
                 "delete_push_rule",
                 delete_push_rule_txn,
                 stream_id,
@@ -582,7 +583,7 @@ class PushRuleStore(PushRulesWorkerStore):
     def set_push_rule_enabled(self, user_id, rule_id, enabled):
         with self._push_rules_stream_id_gen.get_next() as ids:
             stream_id, event_stream_ordering = ids
-            yield self.runInteraction(
+            yield self.db.runInteraction(
                 "_set_push_rule_enabled_txn",
                 self._set_push_rule_enabled_txn,
                 stream_id,
@@ -596,7 +597,7 @@ class PushRuleStore(PushRulesWorkerStore):
         self, txn, stream_id, event_stream_ordering, user_id, rule_id, enabled
     ):
         new_id = self._push_rules_enable_id_gen.get_next()
-        self.simple_upsert_txn(
+        self.db.simple_upsert_txn(
             txn,
             "push_rules_enable",
             {"user_name": user_id, "rule_id": rule_id},
@@ -636,7 +637,7 @@ class PushRuleStore(PushRulesWorkerStore):
                     update_stream=False,
                 )
             else:
-                self.simple_update_one_txn(
+                self.db.simple_update_one_txn(
                     txn,
                     "push_rules",
                     {"user_name": user_id, "rule_id": rule_id},
@@ -655,7 +656,7 @@ class PushRuleStore(PushRulesWorkerStore):
 
         with self._push_rules_stream_id_gen.get_next() as ids:
             stream_id, event_stream_ordering = ids
-            yield self.runInteraction(
+            yield self.db.runInteraction(
                 "set_push_rule_actions",
                 set_push_rule_actions_txn,
                 stream_id,
@@ -675,7 +676,7 @@ class PushRuleStore(PushRulesWorkerStore):
         if data is not None:
             values.update(data)
 
-        self.simple_insert_txn(txn, "push_rules_stream", values=values)
+        self.db.simple_insert_txn(txn, "push_rules_stream", values=values)
 
         txn.call_after(self.get_push_rules_for_user.invalidate, (user_id,))
         txn.call_after(self.get_push_rules_enabled_for_user.invalidate, (user_id,))
@@ -699,7 +700,7 @@ class PushRuleStore(PushRulesWorkerStore):
             txn.execute(sql, (last_id, current_id, limit))
             return txn.fetchall()
 
-        return self.runInteraction(
+        return self.db.runInteraction(
             "get_all_push_rule_updates", get_all_push_rule_updates_txn
         )
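
Store constructors change in lockstep: they now take the `Database` first and thread it through `super().__init__`, after which the helpers hang off `self.db`, while `db_conn` stays available for startup-time reads such as the `get_cache_dict` call above. A hedged sketch of the new signature (the class name is hypothetical):

    from synapse.storage._base import SQLBaseStore
    from synapse.storage.database import Database

    class ExampleWorkerStore(SQLBaseStore):
        def __init__(self, database: Database, db_conn, hs):
            # The base class exposes the Database as self.db.
            super(ExampleWorkerStore, self).__init__(database, db_conn, hs)
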
 
diff --git a/synapse/storage/data_stores/main/pusher.py b/synapse/storage/data_stores/main/pusher.py
index d5a169872b..f07309ef09 100644
--- a/synapse/storage/data_stores/main/pusher.py
+++ b/synapse/storage/data_stores/main/pusher.py
@@ -59,7 +59,7 @@ class PusherWorkerStore(SQLBaseStore):
 
     @defer.inlineCallbacks
     def user_has_pusher(self, user_id):
-        ret = yield self.simple_select_one_onecol(
+        ret = yield self.db.simple_select_one_onecol(
             "pushers", {"user_name": user_id}, "id", allow_none=True
         )
         return ret is not None
@@ -72,7 +72,7 @@ class PusherWorkerStore(SQLBaseStore):
 
     @defer.inlineCallbacks
     def get_pushers_by(self, keyvalues):
-        ret = yield self.simple_select_list(
+        ret = yield self.db.simple_select_list(
             "pushers",
             keyvalues,
             [
@@ -100,11 +100,11 @@ class PusherWorkerStore(SQLBaseStore):
     def get_all_pushers(self):
         def get_pushers(txn):
             txn.execute("SELECT * FROM pushers")
-            rows = self.cursor_to_dict(txn)
+            rows = self.db.cursor_to_dict(txn)
 
             return self._decode_pushers_rows(rows)
 
-        rows = yield self.runInteraction("get_all_pushers", get_pushers)
+        rows = yield self.db.runInteraction("get_all_pushers", get_pushers)
         return rows
 
     def get_all_updated_pushers(self, last_id, current_id, limit):
@@ -134,7 +134,7 @@ class PusherWorkerStore(SQLBaseStore):
 
             return updated, deleted
 
-        return self.runInteraction(
+        return self.db.runInteraction(
             "get_all_updated_pushers", get_all_updated_pushers_txn
         )
 
@@ -177,7 +177,7 @@ class PusherWorkerStore(SQLBaseStore):
 
             return results
 
-        return self.runInteraction(
+        return self.db.runInteraction(
             "get_all_updated_pushers_rows", get_all_updated_pushers_rows_txn
         )
 
@@ -193,7 +193,7 @@ class PusherWorkerStore(SQLBaseStore):
         inlineCallbacks=True,
     )
     def get_if_users_have_pushers(self, user_ids):
-        rows = yield self.simple_select_many_batch(
+        rows = yield self.db.simple_select_many_batch(
             table="pushers",
             column="user_name",
             iterable=user_ids,
@@ -230,7 +230,7 @@ class PusherStore(PusherWorkerStore):
         with self._pushers_id_gen.get_next() as stream_id:
             # no need to lock because `pushers` has a unique key on
             # (app_id, pushkey, user_name) so simple_upsert will retry
-            yield self.simple_upsert(
+            yield self.db.simple_upsert(
                 table="pushers",
                 keyvalues={"app_id": app_id, "pushkey": pushkey, "user_name": user_id},
                 values={
@@ -255,7 +255,7 @@ class PusherStore(PusherWorkerStore):
 
             if user_has_pusher is not True:
                # invalidate, since the user might not have had a pusher before
-                yield self.runInteraction(
+                yield self.db.runInteraction(
                     "add_pusher",
                     self._invalidate_cache_and_stream,
                     self.get_if_user_has_pusher,
@@ -269,7 +269,7 @@ class PusherStore(PusherWorkerStore):
                 txn, self.get_if_user_has_pusher, (user_id,)
             )
 
-            self.simple_delete_one_txn(
+            self.db.simple_delete_one_txn(
                 txn,
                 "pushers",
                 {"app_id": app_id, "pushkey": pushkey, "user_name": user_id},
@@ -278,7 +278,7 @@ class PusherStore(PusherWorkerStore):
             # it's possible for us to end up with duplicate rows for
             # (app_id, pushkey, user_id) at different stream_ids, but that
             # doesn't really matter.
-            self.simple_insert_txn(
+            self.db.simple_insert_txn(
                 txn,
                 table="deleted_pushers",
                 values={
@@ -290,13 +290,13 @@ class PusherStore(PusherWorkerStore):
             )
 
         with self._pushers_id_gen.get_next() as stream_id:
-            yield self.runInteraction("delete_pusher", delete_pusher_txn, stream_id)
+            yield self.db.runInteraction("delete_pusher", delete_pusher_txn, stream_id)
 
     @defer.inlineCallbacks
     def update_pusher_last_stream_ordering(
         self, app_id, pushkey, user_id, last_stream_ordering
     ):
-        yield self.simple_update_one(
+        yield self.db.simple_update_one(
             "pushers",
             {"app_id": app_id, "pushkey": pushkey, "user_name": user_id},
             {"last_stream_ordering": last_stream_ordering},
@@ -319,7 +319,7 @@ class PusherStore(PusherWorkerStore):
         Returns:
             Deferred[bool]: True if the pusher still exists; False if it has been deleted.
         """
-        updated = yield self.simple_update(
+        updated = yield self.db.simple_update(
             table="pushers",
             keyvalues={"app_id": app_id, "pushkey": pushkey, "user_name": user_id},
             updatevalues={
@@ -333,7 +333,7 @@ class PusherStore(PusherWorkerStore):
 
     @defer.inlineCallbacks
     def update_pusher_failing_since(self, app_id, pushkey, user_id, failing_since):
-        yield self.simple_update(
+        yield self.db.simple_update(
             table="pushers",
             keyvalues={"app_id": app_id, "pushkey": pushkey, "user_name": user_id},
             updatevalues={"failing_since": failing_since},
@@ -342,7 +342,7 @@ class PusherStore(PusherWorkerStore):
 
     @defer.inlineCallbacks
     def get_throttle_params_by_room(self, pusher_id):
-        res = yield self.simple_select_list(
+        res = yield self.db.simple_select_list(
             "pusher_throttle",
             {"pusher": pusher_id},
             ["room_id", "last_sent_ts", "throttle_ms"],
@@ -362,7 +362,7 @@ class PusherStore(PusherWorkerStore):
     def set_throttle_params(self, pusher_id, room_id, params):
         # no need to lock because `pusher_throttle` has a primary key on
         # (pusher, room_id) so simple_upsert will retry
-        yield self.simple_upsert(
+        yield self.db.simple_upsert(
             "pusher_throttle",
             {"pusher": pusher_id, "room_id": room_id},
             params,
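
`simple_upsert` separates the key columns from the columns to set: a row matching `keyvalues` is updated with `values`, and if none exists a row combining both is inserted. As the comments above note, no lock is needed when `keyvalues` is covered by a unique or primary key, since the helper retries on conflict. A hedged sketch mirroring `set_throttle_params` (the method name is hypothetical):

    @defer.inlineCallbacks
    def set_throttle_ms(self, pusher_id, room_id, throttle_ms):
        # Keyed on (pusher, room_id): update the existing row or insert one.
        yield self.db.simple_upsert(
            "pusher_throttle",
            {"pusher": pusher_id, "room_id": room_id},
            {"throttle_ms": throttle_ms},
            desc="set_throttle_ms",
        )
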
diff --git a/synapse/storage/data_stores/main/receipts.py b/synapse/storage/data_stores/main/receipts.py
index 380f388e30..96e54d145e 100644
--- a/synapse/storage/data_stores/main/receipts.py
+++ b/synapse/storage/data_stores/main/receipts.py
@@ -22,6 +22,7 @@ from canonicaljson import json
 from twisted.internet import defer
 
 from synapse.storage._base import SQLBaseStore, make_in_list_sql_clause
+from synapse.storage.database import Database
 from synapse.storage.util.id_generators import StreamIdGenerator
 from synapse.util.caches.descriptors import cached, cachedInlineCallbacks, cachedList
 from synapse.util.caches.stream_change_cache import StreamChangeCache
@@ -38,8 +39,8 @@ class ReceiptsWorkerStore(SQLBaseStore):
     # the abstract methods being implemented.
     __metaclass__ = abc.ABCMeta
 
-    def __init__(self, db_conn, hs):
-        super(ReceiptsWorkerStore, self).__init__(db_conn, hs)
+    def __init__(self, database: Database, db_conn, hs):
+        super(ReceiptsWorkerStore, self).__init__(database, db_conn, hs)
 
         self._receipts_stream_cache = StreamChangeCache(
             "ReceiptsRoomChangeCache", self.get_max_receipt_stream_id()
@@ -61,7 +62,7 @@ class ReceiptsWorkerStore(SQLBaseStore):
 
     @cached(num_args=2)
     def get_receipts_for_room(self, room_id, receipt_type):
-        return self.simple_select_list(
+        return self.db.simple_select_list(
             table="receipts_linearized",
             keyvalues={"room_id": room_id, "receipt_type": receipt_type},
             retcols=("user_id", "event_id"),
@@ -70,7 +71,7 @@ class ReceiptsWorkerStore(SQLBaseStore):
 
     @cached(num_args=3)
     def get_last_receipt_event_id_for_user(self, user_id, room_id, receipt_type):
-        return self.simple_select_one_onecol(
+        return self.db.simple_select_one_onecol(
             table="receipts_linearized",
             keyvalues={
                 "room_id": room_id,
@@ -84,7 +85,7 @@ class ReceiptsWorkerStore(SQLBaseStore):
 
     @cachedInlineCallbacks(num_args=2)
     def get_receipts_for_user(self, user_id, receipt_type):
-        rows = yield self.simple_select_list(
+        rows = yield self.db.simple_select_list(
             table="receipts_linearized",
             keyvalues={"user_id": user_id, "receipt_type": receipt_type},
             retcols=("room_id", "event_id"),
@@ -108,7 +109,7 @@ class ReceiptsWorkerStore(SQLBaseStore):
             txn.execute(sql, (user_id,))
             return txn.fetchall()
 
-        rows = yield self.runInteraction("get_receipts_for_user_with_orderings", f)
+        rows = yield self.db.runInteraction("get_receipts_for_user_with_orderings", f)
         return {
             row[0]: {
                 "event_id": row[1],
@@ -187,11 +188,11 @@ class ReceiptsWorkerStore(SQLBaseStore):
 
                 txn.execute(sql, (room_id, to_key))
 
-            rows = self.cursor_to_dict(txn)
+            rows = self.db.cursor_to_dict(txn)
 
             return rows
 
-        rows = yield self.runInteraction("get_linearized_receipts_for_room", f)
+        rows = yield self.db.runInteraction("get_linearized_receipts_for_room", f)
 
         if not rows:
             return []
@@ -237,9 +238,11 @@ class ReceiptsWorkerStore(SQLBaseStore):
 
                 txn.execute(sql + clause, [to_key] + list(args))
 
-            return self.cursor_to_dict(txn)
+            return self.db.cursor_to_dict(txn)
 
-        txn_results = yield self.runInteraction("_get_linearized_receipts_for_rooms", f)
+        txn_results = yield self.db.runInteraction(
+            "_get_linearized_receipts_for_rooms", f
+        )
 
         results = {}
         for row in txn_results:
@@ -282,7 +285,7 @@ class ReceiptsWorkerStore(SQLBaseStore):
 
             return list(r[0:5] + (json.loads(r[5]),) for r in txn)
 
-        return self.runInteraction(
+        return self.db.runInteraction(
             "get_all_updated_receipts", get_all_updated_receipts_txn
         )
 
@@ -313,14 +316,14 @@ class ReceiptsWorkerStore(SQLBaseStore):
 
 
 class ReceiptsStore(ReceiptsWorkerStore):
-    def __init__(self, db_conn, hs):
+    def __init__(self, database: Database, db_conn, hs):
         # We instantiate this first as the ReceiptsWorkerStore constructor
         # needs to be able to call get_max_receipt_stream_id
         self._receipts_id_gen = StreamIdGenerator(
             db_conn, "receipts_linearized", "stream_id"
         )
 
-        super(ReceiptsStore, self).__init__(db_conn, hs)
+        super(ReceiptsStore, self).__init__(database, db_conn, hs)
 
     def get_max_receipt_stream_id(self):
         return self._receipts_id_gen.get_current_token()
@@ -335,7 +338,7 @@ class ReceiptsStore(ReceiptsWorkerStore):
             otherwise, the rx timestamp of the event that the RR corresponds to
                 (or 0 if the event is unknown)
         """
-        res = self.simple_select_one_txn(
+        res = self.db.simple_select_one_txn(
             txn,
             table="events",
             retcols=["stream_ordering", "received_ts"],
@@ -388,7 +391,7 @@ class ReceiptsStore(ReceiptsWorkerStore):
             (user_id, room_id, receipt_type),
         )
 
-        self.simple_delete_txn(
+        self.db.simple_delete_txn(
             txn,
             table="receipts_linearized",
             keyvalues={
@@ -398,7 +401,7 @@ class ReceiptsStore(ReceiptsWorkerStore):
             },
         )
 
-        self.simple_insert_txn(
+        self.db.simple_insert_txn(
             txn,
             table="receipts_linearized",
             values={
@@ -453,13 +456,13 @@ class ReceiptsStore(ReceiptsWorkerStore):
                 else:
                     raise RuntimeError("Unrecognized event_ids: %r" % (event_ids,))
 
-            linearized_event_id = yield self.runInteraction(
+            linearized_event_id = yield self.db.runInteraction(
                 "insert_receipt_conv", graph_to_linear
             )
 
         stream_id_manager = self._receipts_id_gen.get_next()
         with stream_id_manager as stream_id:
-            event_ts = yield self.runInteraction(
+            event_ts = yield self.db.runInteraction(
                 "insert_linearized_receipt",
                 self.insert_linearized_receipt_txn,
                 room_id,
@@ -488,7 +491,7 @@ class ReceiptsStore(ReceiptsWorkerStore):
         return stream_id, max_persisted_id
 
     def insert_graph_receipt(self, room_id, receipt_type, user_id, event_ids, data):
-        return self.runInteraction(
+        return self.db.runInteraction(
             "insert_graph_receipt",
             self.insert_graph_receipt_txn,
             room_id,
@@ -514,7 +517,7 @@ class ReceiptsStore(ReceiptsWorkerStore):
             self._get_linearized_receipts_for_room.invalidate_many, (room_id,)
         )
 
-        self.simple_delete_txn(
+        self.db.simple_delete_txn(
             txn,
             table="receipts_graph",
             keyvalues={
@@ -523,7 +526,7 @@ class ReceiptsStore(ReceiptsWorkerStore):
                 "user_id": user_id,
             },
         )
-        self.simple_insert_txn(
+        self.db.simple_insert_txn(
             txn,
             table="receipts_graph",
             values={
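
`cursor_to_dict`, also relocated onto `self.db`, converts the cursor's fetched rows into a list of column-name-to-value dicts, which is the shape the transaction functions above return for row-like results. A hedged sketch (the method name is hypothetical):

    def get_room_receipt_rows(self, room_id):
        def f(txn):
            txn.execute(
                "SELECT * FROM receipts_linearized WHERE room_id = ?",
                (room_id,),
            )
            # One {column_name: value} dict per row.
            return self.db.cursor_to_dict(txn)

        return self.db.runInteraction("get_room_receipt_rows", f)
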
diff --git a/synapse/storage/data_stores/main/registration.py b/synapse/storage/data_stores/main/registration.py
index debc6706f5..5e8ecac0ea 100644
--- a/synapse/storage/data_stores/main/registration.py
+++ b/synapse/storage/data_stores/main/registration.py
@@ -26,8 +26,8 @@ from twisted.internet.defer import Deferred
 from synapse.api.constants import UserTypes
 from synapse.api.errors import Codes, StoreError, SynapseError, ThreepidValidationError
 from synapse.metrics.background_process_metrics import run_as_background_process
-from synapse.storage import background_updates
 from synapse.storage._base import SQLBaseStore
+from synapse.storage.database import Database
 from synapse.types import UserID
 from synapse.util.caches.descriptors import cached, cachedInlineCallbacks
 
@@ -37,15 +37,15 @@ logger = logging.getLogger(__name__)
 
 
 class RegistrationWorkerStore(SQLBaseStore):
-    def __init__(self, db_conn, hs):
-        super(RegistrationWorkerStore, self).__init__(db_conn, hs)
+    def __init__(self, database: Database, db_conn, hs):
+        super(RegistrationWorkerStore, self).__init__(database, db_conn, hs)
 
         self.config = hs.config
         self.clock = hs.get_clock()
 
     @cached()
     def get_user_by_id(self, user_id):
-        return self.simple_select_one(
+        return self.db.simple_select_one(
             table="users",
             keyvalues={"name": user_id},
             retcols=[
@@ -94,7 +94,7 @@ class RegistrationWorkerStore(SQLBaseStore):
                 including the keys `name`, `is_guest`, `device_id`, `token_id`,
                 `valid_until_ms`.
         """
-        return self.runInteraction(
+        return self.db.runInteraction(
             "get_user_by_access_token", self._query_for_auth, token
         )
 
@@ -109,7 +109,7 @@ class RegistrationWorkerStore(SQLBaseStore):
                 otherwise int representation of the timestamp (as a number of
                 milliseconds since epoch).
         """
-        res = yield self.simple_select_one_onecol(
+        res = yield self.db.simple_select_one_onecol(
             table="account_validity",
             keyvalues={"user_id": user_id},
             retcol="expiration_ts_ms",
@@ -137,7 +137,7 @@ class RegistrationWorkerStore(SQLBaseStore):
         """
 
         def set_account_validity_for_user_txn(txn):
-            self.simple_update_txn(
+            self.db.simple_update_txn(
                 txn=txn,
                 table="account_validity",
                 keyvalues={"user_id": user_id},
@@ -151,7 +151,7 @@ class RegistrationWorkerStore(SQLBaseStore):
                 txn, self.get_expiration_ts_for_user, (user_id,)
             )
 
-        yield self.runInteraction(
+        yield self.db.runInteraction(
             "set_account_validity_for_user", set_account_validity_for_user_txn
         )
 
@@ -167,7 +167,7 @@ class RegistrationWorkerStore(SQLBaseStore):
         Raises:
             StoreError: The provided token is already set for another user.
         """
-        yield self.simple_update_one(
+        yield self.db.simple_update_one(
             table="account_validity",
             keyvalues={"user_id": user_id},
             updatevalues={"renewal_token": renewal_token},
@@ -184,7 +184,7 @@ class RegistrationWorkerStore(SQLBaseStore):
         Returns:
             defer.Deferred[str]: The ID of the user to which the token belongs.
         """
-        res = yield self.simple_select_one_onecol(
+        res = yield self.db.simple_select_one_onecol(
             table="account_validity",
             keyvalues={"renewal_token": renewal_token},
             retcol="user_id",
@@ -203,7 +203,7 @@ class RegistrationWorkerStore(SQLBaseStore):
         Returns:
             defer.Deferred[str]: The renewal token associated with this user ID.
         """
-        res = yield self.simple_select_one_onecol(
+        res = yield self.db.simple_select_one_onecol(
             table="account_validity",
             keyvalues={"user_id": user_id},
             retcol="renewal_token",
@@ -229,9 +229,9 @@ class RegistrationWorkerStore(SQLBaseStore):
             )
             values = [False, now_ms, renew_at]
             txn.execute(sql, values)
-            return self.cursor_to_dict(txn)
+            return self.db.cursor_to_dict(txn)
 
-        res = yield self.runInteraction(
+        res = yield self.db.runInteraction(
             "get_users_expiring_soon",
             select_users_txn,
             self.clock.time_msec(),
@@ -250,7 +250,7 @@ class RegistrationWorkerStore(SQLBaseStore):
             email_sent (bool): Flag which indicates whether a renewal email has been sent
                 to this user.
         """
-        yield self.simple_update_one(
+        yield self.db.simple_update_one(
             table="account_validity",
             keyvalues={"user_id": user_id},
             updatevalues={"email_sent": email_sent},
@@ -265,7 +265,7 @@ class RegistrationWorkerStore(SQLBaseStore):
         Args:
             user_id (str): ID of the user to remove from the account validity table.
         """
-        yield self.simple_delete_one(
+        yield self.db.simple_delete_one(
             table="account_validity",
             keyvalues={"user_id": user_id},
             desc="delete_account_validity_for_user",
@@ -281,7 +281,7 @@ class RegistrationWorkerStore(SQLBaseStore):
         Returns (bool):
             true iff the user is a server admin, false otherwise.
         """
-        res = yield self.simple_select_one_onecol(
+        res = yield self.db.simple_select_one_onecol(
             table="users",
             keyvalues={"name": user.to_string()},
             retcol="admin",
@@ -299,7 +299,7 @@ class RegistrationWorkerStore(SQLBaseStore):
             admin (bool): true iff the user is to be a server admin,
                 false otherwise.
         """
-        return self.simple_update_one(
+        return self.db.simple_update_one(
             table="users",
             keyvalues={"name": user.to_string()},
             updatevalues={"admin": 1 if admin else 0},
@@ -316,7 +316,7 @@ class RegistrationWorkerStore(SQLBaseStore):
         )
 
         txn.execute(sql, (token,))
-        rows = self.cursor_to_dict(txn)
+        rows = self.db.cursor_to_dict(txn)
         if rows:
             return rows[0]
 
@@ -332,7 +332,9 @@ class RegistrationWorkerStore(SQLBaseStore):
         Returns:
             Deferred[bool]: True if the user's 'user_type' is null or an empty string
         """
-        res = yield self.runInteraction("is_real_user", self.is_real_user_txn, user_id)
+        res = yield self.db.runInteraction(
+            "is_real_user", self.is_real_user_txn, user_id
+        )
         return res
 
     @cachedInlineCallbacks()
@@ -345,13 +347,13 @@ class RegistrationWorkerStore(SQLBaseStore):
         Returns:
             Deferred[bool]: True if user is of type UserTypes.SUPPORT
         """
-        res = yield self.runInteraction(
+        res = yield self.db.runInteraction(
             "is_support_user", self.is_support_user_txn, user_id
         )
         return res
 
     def is_real_user_txn(self, txn, user_id):
-        res = self.simple_select_one_onecol_txn(
+        res = self.db.simple_select_one_onecol_txn(
             txn=txn,
             table="users",
             keyvalues={"name": user_id},
@@ -361,7 +363,7 @@ class RegistrationWorkerStore(SQLBaseStore):
         return res is None
 
     def is_support_user_txn(self, txn, user_id):
-        res = self.simple_select_one_onecol_txn(
+        res = self.db.simple_select_one_onecol_txn(
             txn=txn,
             table="users",
             keyvalues={"name": user_id},
@@ -380,7 +382,7 @@ class RegistrationWorkerStore(SQLBaseStore):
             txn.execute(sql, (user_id,))
             return dict(txn)
 
-        return self.runInteraction("get_users_by_id_case_insensitive", f)
+        return self.db.runInteraction("get_users_by_id_case_insensitive", f)
 
     async def get_user_by_external_id(
         self, auth_provider: str, external_id: str
@@ -394,7 +396,7 @@ class RegistrationWorkerStore(SQLBaseStore):
         Returns:
             str|None: the mxid of the user, or None if they are not known
         """
-        return await self.simple_select_one_onecol(
+        return await self.db.simple_select_one_onecol(
             table="user_external_ids",
             keyvalues={"auth_provider": auth_provider, "external_id": external_id},
             retcol="user_id",
@@ -408,12 +410,12 @@ class RegistrationWorkerStore(SQLBaseStore):
 
         def _count_users(txn):
             txn.execute("SELECT COUNT(*) AS users FROM users")
-            rows = self.cursor_to_dict(txn)
+            rows = self.db.cursor_to_dict(txn)
             if rows:
                 return rows[0]["users"]
             return 0
 
-        ret = yield self.runInteraction("count_users", _count_users)
+        ret = yield self.db.runInteraction("count_users", _count_users)
         return ret
 
     def count_daily_user_type(self):
@@ -445,7 +447,7 @@ class RegistrationWorkerStore(SQLBaseStore):
                 results[row[0]] = row[1]
             return results
 
-        return self.runInteraction("count_daily_user_type", _count_daily_user_type)
+        return self.db.runInteraction("count_daily_user_type", _count_daily_user_type)
 
     @defer.inlineCallbacks
     def count_nonbridged_users(self):
@@ -459,7 +461,7 @@ class RegistrationWorkerStore(SQLBaseStore):
             (count,) = txn.fetchone()
             return count
 
-        ret = yield self.runInteraction("count_users", _count_users)
+        ret = yield self.db.runInteraction("count_users", _count_users)
         return ret
 
     @defer.inlineCallbacks
@@ -468,12 +470,12 @@ class RegistrationWorkerStore(SQLBaseStore):
 
         def _count_users(txn):
             txn.execute("SELECT COUNT(*) AS users FROM users where user_type is null")
-            rows = self.cursor_to_dict(txn)
+            rows = self.db.cursor_to_dict(txn)
             if rows:
                 return rows[0]["users"]
             return 0
 
-        ret = yield self.runInteraction("count_real_users", _count_users)
+        ret = yield self.db.runInteraction("count_real_users", _count_users)
         return ret
 
     @defer.inlineCallbacks
@@ -503,7 +505,7 @@ class RegistrationWorkerStore(SQLBaseStore):
 
         return (
             (
-                yield self.runInteraction(
+                yield self.db.runInteraction(
                     "find_next_generated_user_id", _find_next_generated_user_id
                 )
             )
@@ -520,7 +522,7 @@ class RegistrationWorkerStore(SQLBaseStore):
         Returns:
             Deferred[str|None]: user id or None if no user id/threepid mapping exists
         """
-        user_id = yield self.runInteraction(
+        user_id = yield self.db.runInteraction(
             "get_user_id_by_threepid", self.get_user_id_by_threepid_txn, medium, address
         )
         return user_id
@@ -536,7 +538,7 @@ class RegistrationWorkerStore(SQLBaseStore):
         Returns:
             str|None: user id or None if no user id/threepid mapping exists
         """
-        ret = self.simple_select_one_txn(
+        ret = self.db.simple_select_one_txn(
             txn,
             "user_threepids",
             {"medium": medium, "address": address},
@@ -549,7 +551,7 @@ class RegistrationWorkerStore(SQLBaseStore):
 
     @defer.inlineCallbacks
     def user_add_threepid(self, user_id, medium, address, validated_at, added_at):
-        yield self.simple_upsert(
+        yield self.db.simple_upsert(
             "user_threepids",
             {"medium": medium, "address": address},
             {"user_id": user_id, "validated_at": validated_at, "added_at": added_at},
@@ -557,7 +559,7 @@ class RegistrationWorkerStore(SQLBaseStore):
 
     @defer.inlineCallbacks
     def user_get_threepids(self, user_id):
-        ret = yield self.simple_select_list(
+        ret = yield self.db.simple_select_list(
             "user_threepids",
             {"user_id": user_id},
             ["medium", "address", "validated_at", "added_at"],
@@ -566,7 +568,7 @@ class RegistrationWorkerStore(SQLBaseStore):
         return ret
 
     def user_delete_threepid(self, user_id, medium, address):
-        return self.simple_delete(
+        return self.db.simple_delete(
             "user_threepids",
             keyvalues={"user_id": user_id, "medium": medium, "address": address},
             desc="user_delete_threepid",
@@ -579,7 +581,7 @@ class RegistrationWorkerStore(SQLBaseStore):
              user_id: The user id to delete all threepids of
 
         """
-        return self.simple_delete(
+        return self.db.simple_delete(
             "user_threepids",
             keyvalues={"user_id": user_id},
             desc="user_delete_threepids",
@@ -601,7 +603,7 @@ class RegistrationWorkerStore(SQLBaseStore):
         """
         # We need to use an upsert, in case the user had already bound the
         # threepid
-        return self.simple_upsert(
+        return self.db.simple_upsert(
             table="user_threepid_id_server",
             keyvalues={
                 "user_id": user_id,
@@ -627,7 +629,7 @@ class RegistrationWorkerStore(SQLBaseStore):
                 medium (str): The medium of the threepid (e.g. "email")
                 address (str): The address of the threepid (e.g. "bob@example.com")
         """
-        return self.simple_select_list(
+        return self.db.simple_select_list(
             table="user_threepid_id_server",
             keyvalues={"user_id": user_id},
             retcols=["medium", "address"],
@@ -648,7 +650,7 @@ class RegistrationWorkerStore(SQLBaseStore):
         Returns:
             Deferred
         """
-        return self.simple_delete(
+        return self.db.simple_delete(
             table="user_threepid_id_server",
             keyvalues={
                 "user_id": user_id,
@@ -671,7 +673,7 @@ class RegistrationWorkerStore(SQLBaseStore):
         Returns:
             Deferred[list[str]]: Resolves to a list of identity servers
         """
-        return self.simple_select_onecol(
+        return self.db.simple_select_onecol(
             table="user_threepid_id_server",
             keyvalues={"user_id": user_id, "medium": medium, "address": address},
             retcol="id_server",
@@ -689,7 +691,7 @@ class RegistrationWorkerStore(SQLBaseStore):
             Deferred[bool]: The requested value.
         """
 
-        res = yield self.simple_select_one_onecol(
+        res = yield self.db.simple_select_one_onecol(
             table="users",
             keyvalues={"name": user_id},
             retcol="deactivated",
@@ -756,13 +758,13 @@ class RegistrationWorkerStore(SQLBaseStore):
             sql += " LIMIT 1"
 
             txn.execute(sql, list(keyvalues.values()))
-            rows = self.cursor_to_dict(txn)
+            rows = self.db.cursor_to_dict(txn)
             if not rows:
                 return None
 
             return rows[0]
 
-        return self.runInteraction(
+        return self.db.runInteraction(
             "get_threepid_validation_session", get_threepid_validation_session_txn
         )
 
@@ -776,39 +778,37 @@ class RegistrationWorkerStore(SQLBaseStore):
         """
 
         def delete_threepid_session_txn(txn):
-            self.simple_delete_txn(
+            self.db.simple_delete_txn(
                 txn,
                 table="threepid_validation_token",
                 keyvalues={"session_id": session_id},
             )
-            self.simple_delete_txn(
+            self.db.simple_delete_txn(
                 txn,
                 table="threepid_validation_session",
                 keyvalues={"session_id": session_id},
             )
 
-        return self.runInteraction(
+        return self.db.runInteraction(
             "delete_threepid_session", delete_threepid_session_txn
         )
 
 
-class RegistrationBackgroundUpdateStore(
-    RegistrationWorkerStore, background_updates.BackgroundUpdateStore
-):
-    def __init__(self, db_conn, hs):
-        super(RegistrationBackgroundUpdateStore, self).__init__(db_conn, hs)
+class RegistrationBackgroundUpdateStore(RegistrationWorkerStore):
+    def __init__(self, database: Database, db_conn, hs):
+        super(RegistrationBackgroundUpdateStore, self).__init__(database, db_conn, hs)
 
         self.clock = hs.get_clock()
         self.config = hs.config
 
-        self.register_background_index_update(
+        self.db.updates.register_background_index_update(
             "access_tokens_device_index",
             index_name="access_tokens_device_id",
             table="access_tokens",
             columns=["user_id", "device_id"],
         )
 
-        self.register_background_index_update(
+        self.db.updates.register_background_index_update(
             "users_creation_ts",
             index_name="users_creation_ts",
             table="users",
@@ -818,13 +818,13 @@ class RegistrationBackgroundUpdateStore(
         # we no longer use refresh tokens, but it's possible that some people
         # might have a background update queued to build this index. Just
         # clear the background update.
-        self.register_noop_background_update("refresh_tokens_device_index")
+        self.db.updates.register_noop_background_update("refresh_tokens_device_index")
 
-        self.register_background_update_handler(
+        self.db.updates.register_background_update_handler(
             "user_threepids_grandfather", self._bg_user_threepids_grandfather
         )
 
-        self.register_background_update_handler(
+        self.db.updates.register_background_update_handler(
             "users_set_deactivated_flag", self._background_update_set_deactivated_flag
         )
 
@@ -857,7 +857,7 @@ class RegistrationBackgroundUpdateStore(
                 (last_user, batch_size),
             )
 
-            rows = self.cursor_to_dict(txn)
+            rows = self.db.cursor_to_dict(txn)
 
             if not rows:
                 return True, 0
@@ -871,7 +871,7 @@ class RegistrationBackgroundUpdateStore(
 
             logger.info("Marked %d rows as deactivated", rows_processed_nb)
 
-            self._background_update_progress_txn(
+            self.db.updates._background_update_progress_txn(
                 txn, "users_set_deactivated_flag", {"user_id": rows[-1]["name"]}
             )
 
@@ -880,12 +880,12 @@ class RegistrationBackgroundUpdateStore(
             else:
                 return False, len(rows)
 
-        end, nb_processed = yield self.runInteraction(
+        end, nb_processed = yield self.db.runInteraction(
             "users_set_deactivated_flag", _background_update_set_deactivated_flag_txn
         )
 
         if end:
-            yield self._end_background_update("users_set_deactivated_flag")
+            yield self.db.updates._end_background_update("users_set_deactivated_flag")
 
         return nb_processed
 
@@ -911,18 +911,18 @@ class RegistrationBackgroundUpdateStore(
             txn.executemany(sql, [(id_server,) for id_server in id_servers])
 
         if id_servers:
-            yield self.runInteraction(
+            yield self.db.runInteraction(
                 "_bg_user_threepids_grandfather", _bg_user_threepids_grandfather_txn
             )
 
-        yield self._end_background_update("user_threepids_grandfather")
+        yield self.db.updates._end_background_update("user_threepids_grandfather")
 
         return 1
 
 
 class RegistrationStore(RegistrationBackgroundUpdateStore):
-    def __init__(self, db_conn, hs):
-        super(RegistrationStore, self).__init__(db_conn, hs)
+    def __init__(self, database: Database, db_conn, hs):
+        super(RegistrationStore, self).__init__(database, db_conn, hs)
 
         self._account_validity = hs.config.account_validity
 
@@ -961,7 +961,7 @@ class RegistrationStore(RegistrationBackgroundUpdateStore):
         """
         next_id = self._access_tokens_id_gen.get_next()
 
-        yield self.simple_insert(
+        yield self.db.simple_insert(
             "access_tokens",
             {
                 "id": next_id,
@@ -1003,7 +1003,7 @@ class RegistrationStore(RegistrationBackgroundUpdateStore):
         Raises:
             StoreError if the user_id could not be registered.
         """
-        return self.runInteraction(
+        return self.db.runInteraction(
             "register_user",
             self._register_user,
             user_id,
@@ -1037,7 +1037,7 @@ class RegistrationStore(RegistrationBackgroundUpdateStore):
                 # Ensure that the guest user actually exists
                 # ``allow_none=False`` makes this raise an exception
                 # if the row isn't in the database.
-                self.simple_select_one_txn(
+                self.db.simple_select_one_txn(
                     txn,
                     "users",
                     keyvalues={"name": user_id, "is_guest": 1},
@@ -1045,7 +1045,7 @@ class RegistrationStore(RegistrationBackgroundUpdateStore):
                     allow_none=False,
                 )
 
-                self.simple_update_one_txn(
+                self.db.simple_update_one_txn(
                     txn,
                     "users",
                     keyvalues={"name": user_id, "is_guest": 1},
@@ -1059,7 +1059,7 @@ class RegistrationStore(RegistrationBackgroundUpdateStore):
                     },
                 )
             else:
-                self.simple_insert_txn(
+                self.db.simple_insert_txn(
                     txn,
                     "users",
                     values={
@@ -1114,7 +1114,7 @@ class RegistrationStore(RegistrationBackgroundUpdateStore):
             external_id: id on that system
             user_id: complete mxid that it is mapped to
         """
-        return self.simple_insert(
+        return self.db.simple_insert(
             table="user_external_ids",
             values={
                 "auth_provider": auth_provider,
@@ -1132,12 +1132,14 @@ class RegistrationStore(RegistrationBackgroundUpdateStore):
         """
 
         def user_set_password_hash_txn(txn):
-            self.simple_update_one_txn(
+            self.db.simple_update_one_txn(
                 txn, "users", {"name": user_id}, {"password_hash": password_hash}
             )
             self._invalidate_cache_and_stream(txn, self.get_user_by_id, (user_id,))
 
-        return self.runInteraction("user_set_password_hash", user_set_password_hash_txn)
+        return self.db.runInteraction(
+            "user_set_password_hash", user_set_password_hash_txn
+        )
 
     def user_set_consent_version(self, user_id, consent_version):
         """Updates the user table to record privacy policy consent
@@ -1152,7 +1154,7 @@ class RegistrationStore(RegistrationBackgroundUpdateStore):
         """
 
         def f(txn):
-            self.simple_update_one_txn(
+            self.db.simple_update_one_txn(
                 txn,
                 table="users",
                 keyvalues={"name": user_id},
@@ -1160,7 +1162,7 @@ class RegistrationStore(RegistrationBackgroundUpdateStore):
             )
             self._invalidate_cache_and_stream(txn, self.get_user_by_id, (user_id,))
 
-        return self.runInteraction("user_set_consent_version", f)
+        return self.db.runInteraction("user_set_consent_version", f)
 
     def user_set_consent_server_notice_sent(self, user_id, consent_version):
         """Updates the user table to record that we have sent the user a server
@@ -1176,7 +1178,7 @@ class RegistrationStore(RegistrationBackgroundUpdateStore):
         """
 
         def f(txn):
-            self.simple_update_one_txn(
+            self.db.simple_update_one_txn(
                 txn,
                 table="users",
                 keyvalues={"name": user_id},
@@ -1184,7 +1186,7 @@ class RegistrationStore(RegistrationBackgroundUpdateStore):
             )
             self._invalidate_cache_and_stream(txn, self.get_user_by_id, (user_id,))
 
-        return self.runInteraction("user_set_consent_server_notice_sent", f)
+        return self.db.runInteraction("user_set_consent_server_notice_sent", f)
 
     def user_delete_access_tokens(self, user_id, except_token_id=None, device_id=None):
         """
@@ -1230,11 +1232,11 @@ class RegistrationStore(RegistrationBackgroundUpdateStore):
 
             return tokens_and_devices
 
-        return self.runInteraction("user_delete_access_tokens", f)
+        return self.db.runInteraction("user_delete_access_tokens", f)
 
     def delete_access_token(self, access_token):
         def f(txn):
-            self.simple_delete_one_txn(
+            self.db.simple_delete_one_txn(
                 txn, table="access_tokens", keyvalues={"token": access_token}
             )
 
@@ -1242,11 +1244,11 @@ class RegistrationStore(RegistrationBackgroundUpdateStore):
                 txn, self.get_user_by_access_token, (access_token,)
             )
 
-        return self.runInteraction("delete_access_token", f)
+        return self.db.runInteraction("delete_access_token", f)
 
     @cachedInlineCallbacks()
     def is_guest(self, user_id):
-        res = yield self.simple_select_one_onecol(
+        res = yield self.db.simple_select_one_onecol(
             table="users",
             keyvalues={"name": user_id},
             retcol="is_guest",
@@ -1261,7 +1263,7 @@ class RegistrationStore(RegistrationBackgroundUpdateStore):
         Adds a user to the table of users who need to be parted from all the rooms they're
         in
         """
-        return self.simple_insert(
+        return self.db.simple_insert(
             "users_pending_deactivation",
             values={"user_id": user_id},
             desc="add_user_pending_deactivation",
@@ -1274,7 +1276,7 @@ class RegistrationStore(RegistrationBackgroundUpdateStore):
         """
         # XXX: This should be simple_delete_one but we failed to put a unique index on
         # the table, so somehow duplicate entries have ended up in it.
-        return self.simple_delete(
+        return self.db.simple_delete(
             "users_pending_deactivation",
             keyvalues={"user_id": user_id},
             desc="del_user_pending_deactivation",
@@ -1285,7 +1287,7 @@ class RegistrationStore(RegistrationBackgroundUpdateStore):
         Gets one user from the table of users waiting to be parted from all the rooms
         they're in.
         """
-        return self.simple_select_one_onecol(
+        return self.db.simple_select_one_onecol(
             "users_pending_deactivation",
             keyvalues={},
             retcol="user_id",
@@ -1315,7 +1317,7 @@ class RegistrationStore(RegistrationBackgroundUpdateStore):
 
         # Run everything in one transaction so it executes atomically
         def validate_threepid_session_txn(txn):
-            row = self.simple_select_one_txn(
+            row = self.db.simple_select_one_txn(
                 txn,
                 table="threepid_validation_session",
                 keyvalues={"session_id": session_id},
@@ -1333,7 +1335,7 @@ class RegistrationStore(RegistrationBackgroundUpdateStore):
                     400, "This client_secret does not match the provided session_id"
                 )
 
-            row = self.simple_select_one_txn(
+            row = self.db.simple_select_one_txn(
                 txn,
                 table="threepid_validation_token",
                 keyvalues={"session_id": session_id, "token": token},
@@ -1358,7 +1360,7 @@ class RegistrationStore(RegistrationBackgroundUpdateStore):
                 )
 
             # Looks good. Validate the session
-            self.simple_update_txn(
+            self.db.simple_update_txn(
                 txn,
                 table="threepid_validation_session",
                 keyvalues={"session_id": session_id},
@@ -1368,7 +1370,7 @@ class RegistrationStore(RegistrationBackgroundUpdateStore):
             return next_link
 
         # Return next_link if it exists
-        return self.runInteraction(
+        return self.db.runInteraction(
             "validate_threepid_session_txn", validate_threepid_session_txn
         )
 
@@ -1401,7 +1403,7 @@ class RegistrationStore(RegistrationBackgroundUpdateStore):
         if validated_at:
             insertion_values["validated_at"] = validated_at
 
-        return self.simple_upsert(
+        return self.db.simple_upsert(
             table="threepid_validation_session",
             keyvalues={"session_id": session_id},
             values={"last_send_attempt": send_attempt},
@@ -1439,7 +1441,7 @@ class RegistrationStore(RegistrationBackgroundUpdateStore):
 
         def start_or_continue_validation_session_txn(txn):
             # Create or update a validation session
-            self.simple_upsert_txn(
+            self.db.simple_upsert_txn(
                 txn,
                 table="threepid_validation_session",
                 keyvalues={"session_id": session_id},
@@ -1452,7 +1454,7 @@ class RegistrationStore(RegistrationBackgroundUpdateStore):
             )
 
             # Create a new validation token with this session ID
-            self.simple_insert_txn(
+            self.db.simple_insert_txn(
                 txn,
                 table="threepid_validation_token",
                 values={
@@ -1463,7 +1465,7 @@ class RegistrationStore(RegistrationBackgroundUpdateStore):
                 },
             )
 
-        return self.runInteraction(
+        return self.db.runInteraction(
             "start_or_continue_validation_session",
             start_or_continue_validation_session_txn,
         )
@@ -1478,7 +1480,7 @@ class RegistrationStore(RegistrationBackgroundUpdateStore):
             """
             return txn.execute(sql, (ts,))
 
-        return self.runInteraction(
+        return self.db.runInteraction(
             "cull_expired_threepid_validation_tokens",
             cull_expired_threepid_validation_tokens_txn,
             self.clock.time_msec(),
@@ -1493,7 +1495,7 @@ class RegistrationStore(RegistrationBackgroundUpdateStore):
             deactivated (bool): The value to set for `deactivated`.
         """
 
-        yield self.runInteraction(
+        yield self.db.runInteraction(
             "set_user_deactivated_status",
             self.set_user_deactivated_status_txn,
             user_id,
@@ -1501,7 +1503,7 @@ class RegistrationStore(RegistrationBackgroundUpdateStore):
         )
 
     def set_user_deactivated_status_txn(self, txn, user_id, deactivated):
-        self.simple_update_one_txn(
+        self.db.simple_update_one_txn(
             txn=txn,
             table="users",
             keyvalues={"name": user_id},
@@ -1529,14 +1531,14 @@ class RegistrationStore(RegistrationBackgroundUpdateStore):
             )
             txn.execute(sql, [])
 
-            res = self.cursor_to_dict(txn)
+            res = self.db.cursor_to_dict(txn)
             if res:
                 for user in res:
                     self.set_expiration_date_for_user_txn(
                         txn, user["name"], use_delta=True
                     )
 
-        yield self.runInteraction(
+        yield self.db.runInteraction(
             "get_users_with_no_expiration_date",
             select_users_with_no_expiration_date_txn,
         )
@@ -1560,7 +1562,7 @@ class RegistrationStore(RegistrationBackgroundUpdateStore):
                 expiration_ts,
             )
 
-        self.simple_upsert_txn(
+        self.db.simple_upsert_txn(
             txn,
             "account_validity",
             keyvalues={"user_id": user_id},
diff --git a/synapse/storage/data_stores/main/rejections.py b/synapse/storage/data_stores/main/rejections.py
index f81f9279a1..1c07c7a425 100644
--- a/synapse/storage/data_stores/main/rejections.py
+++ b/synapse/storage/data_stores/main/rejections.py
@@ -22,7 +22,7 @@ logger = logging.getLogger(__name__)
 
 class RejectionsStore(SQLBaseStore):
     def _store_rejections_txn(self, txn, event_id, reason):
-        self.simple_insert_txn(
+        self.db.simple_insert_txn(
             txn,
             table="rejections",
             values={
@@ -33,7 +33,7 @@ class RejectionsStore(SQLBaseStore):
         )
 
     def get_rejection_reason(self, event_id):
-        return self.simple_select_one_onecol(
+        return self.db.simple_select_one_onecol(
             table="rejections",
             retcol="reason",
             keyvalues={"event_id": event_id},
diff --git a/synapse/storage/data_stores/main/relations.py b/synapse/storage/data_stores/main/relations.py
index aa5e10538b..046c2b4845 100644
--- a/synapse/storage/data_stores/main/relations.py
+++ b/synapse/storage/data_stores/main/relations.py
@@ -129,7 +129,7 @@ class RelationsWorkerStore(SQLBaseStore):
                 chunk=list(events[:limit]), next_batch=next_batch, prev_batch=from_token
             )
 
-        return self.runInteraction(
+        return self.db.runInteraction(
             "get_recent_references_for_event", _get_recent_references_for_event_txn
         )
 
@@ -223,7 +223,7 @@ class RelationsWorkerStore(SQLBaseStore):
                 chunk=list(events[:limit]), next_batch=next_batch, prev_batch=from_token
             )
 
-        return self.runInteraction(
+        return self.db.runInteraction(
             "get_aggregation_groups_for_event", _get_aggregation_groups_for_event_txn
         )
 
@@ -268,7 +268,7 @@ class RelationsWorkerStore(SQLBaseStore):
             if row:
                 return row[0]
 
-        edit_id = yield self.runInteraction(
+        edit_id = yield self.db.runInteraction(
             "get_applicable_edit", _get_applicable_edit_txn
         )
 
@@ -318,7 +318,7 @@ class RelationsWorkerStore(SQLBaseStore):
 
             return bool(txn.fetchone())
 
-        return self.runInteraction(
+        return self.db.runInteraction(
             "get_if_user_has_annotated_event", _get_if_user_has_annotated_event
         )
 
@@ -352,7 +352,7 @@ class RelationsStore(RelationsWorkerStore):
 
         aggregation_key = relation.get("key")
 
-        self.simple_insert_txn(
+        self.db.simple_insert_txn(
             txn,
             table="event_relations",
             values={
@@ -380,6 +380,6 @@ class RelationsStore(RelationsWorkerStore):
             redacted_event_id (str): The event that was redacted.
         """
 
-        self.simple_delete_txn(
+        self.db.simple_delete_txn(
             txn, table="event_relations", keyvalues={"event_id": redacted_event_id}
         )
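
These hunks all apply one mechanical change: query helpers such as simple_insert_txn, simple_select_one and runInteraction are no longer inherited from the store base class but are reached through a Database object held as self.db. A minimal sketch of the shape this implies, where only the method names come from the diff and the Database internals are assumptions:

    # Hedged sketch: Database owns the low-level helpers that the stores
    # now reach via self.db. Signatures are illustrative, not Synapse's.
    class Database(object):
        def simple_insert_txn(self, txn, table, values):
            # Build a parameterised INSERT from the values mapping.
            keys = sorted(values)
            sql = "INSERT INTO %s (%s) VALUES (%s)" % (
                table, ", ".join(keys), ", ".join("?" for _ in keys),
            )
            txn.execute(sql, [values[k] for k in keys])

    class SQLBaseStore(object):
        def __init__(self, database, db_conn, hs):
            self.hs = hs
            self.db = database  # delegate to the shared Database object
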
diff --git a/synapse/storage/data_stores/main/room.py b/synapse/storage/data_stores/main/room.py
index f309e3640c..0148be20d3 100644
--- a/synapse/storage/data_stores/main/room.py
+++ b/synapse/storage/data_stores/main/room.py
@@ -28,8 +28,8 @@ from twisted.internet import defer
 from synapse.api.constants import EventTypes
 from synapse.api.errors import StoreError
 from synapse.storage._base import SQLBaseStore
-from synapse.storage.background_updates import BackgroundUpdateStore
 from synapse.storage.data_stores.main.search import SearchStore
+from synapse.storage.database import Database
 from synapse.types import ThirdPartyInstanceID
 from synapse.util.caches.descriptors import cached, cachedInlineCallbacks
 
@@ -54,7 +54,7 @@ class RoomWorkerStore(SQLBaseStore):
         Returns:
             A dict containing the room information, or None if the room is unknown.
         """
-        return self.simple_select_one(
+        return self.db.simple_select_one(
             table="rooms",
             keyvalues={"room_id": room_id},
             retcols=("room_id", "is_public", "creator"),
@@ -63,7 +63,7 @@ class RoomWorkerStore(SQLBaseStore):
         )
 
     def get_public_room_ids(self):
-        return self.simple_select_onecol(
+        return self.db.simple_select_onecol(
             table="rooms",
             keyvalues={"is_public": True},
             retcol="room_id",
@@ -120,7 +120,7 @@ class RoomWorkerStore(SQLBaseStore):
             txn.execute(sql, query_args)
             return txn.fetchone()[0]
 
-        return self.runInteraction("count_public_rooms", _count_public_rooms_txn)
+        return self.db.runInteraction("count_public_rooms", _count_public_rooms_txn)
 
     @defer.inlineCallbacks
     def get_largest_public_rooms(
@@ -253,21 +253,21 @@ class RoomWorkerStore(SQLBaseStore):
         def _get_largest_public_rooms_txn(txn):
             txn.execute(sql, query_args)
 
-            results = self.cursor_to_dict(txn)
+            results = self.db.cursor_to_dict(txn)
 
             if not forwards:
                 results.reverse()
 
             return results
 
-        ret_val = yield self.runInteraction(
+        ret_val = yield self.db.runInteraction(
             "get_largest_public_rooms", _get_largest_public_rooms_txn
         )
         defer.returnValue(ret_val)
 
     @cached(max_entries=10000)
     def is_room_blocked(self, room_id):
-        return self.simple_select_one_onecol(
+        return self.db.simple_select_one_onecol(
             table="blocked_rooms",
             keyvalues={"room_id": room_id},
             retcol="1",
@@ -288,7 +288,7 @@ class RoomWorkerStore(SQLBaseStore):
             of RatelimitOverride are None or 0 then ratelimiting has been
             disabled for that user entirely.
         """
-        row = yield self.simple_select_one(
+        row = yield self.db.simple_select_one(
             table="ratelimit_override",
             keyvalues={"user_id": user_id},
             retcols=("messages_per_second", "burst_count"),
@@ -330,9 +330,9 @@ class RoomWorkerStore(SQLBaseStore):
                 (room_id,),
             )
 
-            return self.cursor_to_dict(txn)
+            return self.db.cursor_to_dict(txn)
 
-        ret = yield self.runInteraction(
+        ret = yield self.db.runInteraction(
             "get_retention_policy_for_room", get_retention_policy_for_room_txn,
         )
 
@@ -361,13 +361,13 @@ class RoomWorkerStore(SQLBaseStore):
         defer.returnValue(row)
 
 
-class RoomBackgroundUpdateStore(BackgroundUpdateStore):
-    def __init__(self, db_conn, hs):
-        super(RoomBackgroundUpdateStore, self).__init__(db_conn, hs)
+class RoomBackgroundUpdateStore(SQLBaseStore):
+    def __init__(self, database: Database, db_conn, hs):
+        super(RoomBackgroundUpdateStore, self).__init__(database, db_conn, hs)
 
         self.config = hs.config
 
-        self.register_background_update_handler(
+        self.db.updates.register_background_update_handler(
             "insert_room_retention", self._background_insert_retention,
         )
 
@@ -396,7 +396,7 @@ class RoomBackgroundUpdateStore(BackgroundUpdateStore):
                 (last_room, batch_size),
             )
 
-            rows = self.cursor_to_dict(txn)
+            rows = self.db.cursor_to_dict(txn)
 
             if not rows:
                 return True
@@ -408,7 +408,7 @@ class RoomBackgroundUpdateStore(BackgroundUpdateStore):
                     ev = json.loads(row["json"])
                     retention_policy = json.dumps(ev["content"])
 
-                self.simple_insert_txn(
+                self.db.simple_insert_txn(
                     txn=txn,
                     table="room_retention",
                     values={
@@ -421,7 +421,7 @@ class RoomBackgroundUpdateStore(BackgroundUpdateStore):
 
             logger.info("Inserted %d rows into room_retention", len(rows))
 
-            self._background_update_progress_txn(
+            self.db.updates._background_update_progress_txn(
                 txn, "insert_room_retention", {"room_id": rows[-1]["room_id"]}
             )
 
@@ -430,19 +430,19 @@ class RoomBackgroundUpdateStore(BackgroundUpdateStore):
             else:
                 return False
 
-        end = yield self.runInteraction(
+        end = yield self.db.runInteraction(
             "insert_room_retention", _background_insert_retention_txn,
         )
 
         if end:
-            yield self._end_background_update("insert_room_retention")
+            yield self.db.updates._end_background_update("insert_room_retention")
 
         defer.returnValue(batch_size)
 
 
 class RoomStore(RoomBackgroundUpdateStore, RoomWorkerStore, SearchStore):
-    def __init__(self, db_conn, hs):
-        super(RoomStore, self).__init__(db_conn, hs)
+    def __init__(self, database: Database, db_conn, hs):
+        super(RoomStore, self).__init__(database, db_conn, hs)
 
         self.config = hs.config
 
@@ -461,7 +461,7 @@ class RoomStore(RoomBackgroundUpdateStore, RoomWorkerStore, SearchStore):
         try:
 
             def store_room_txn(txn, next_id):
-                self.simple_insert_txn(
+                self.db.simple_insert_txn(
                     txn,
                     "rooms",
                     {
@@ -471,7 +471,7 @@ class RoomStore(RoomBackgroundUpdateStore, RoomWorkerStore, SearchStore):
                     },
                 )
                 if is_public:
-                    self.simple_insert_txn(
+                    self.db.simple_insert_txn(
                         txn,
                         table="public_room_list_stream",
                         values={
@@ -482,7 +482,7 @@ class RoomStore(RoomBackgroundUpdateStore, RoomWorkerStore, SearchStore):
                     )
 
             with self._public_room_id_gen.get_next() as next_id:
-                yield self.runInteraction("store_room_txn", store_room_txn, next_id)
+                yield self.db.runInteraction("store_room_txn", store_room_txn, next_id)
         except Exception as e:
             logger.error("store_room with room_id=%s failed: %s", room_id, e)
             raise StoreError(500, "Problem creating room.")
@@ -490,14 +490,14 @@ class RoomStore(RoomBackgroundUpdateStore, RoomWorkerStore, SearchStore):
     @defer.inlineCallbacks
     def set_room_is_public(self, room_id, is_public):
         def set_room_is_public_txn(txn, next_id):
-            self.simple_update_one_txn(
+            self.db.simple_update_one_txn(
                 txn,
                 table="rooms",
                 keyvalues={"room_id": room_id},
                 updatevalues={"is_public": is_public},
             )
 
-            entries = self.simple_select_list_txn(
+            entries = self.db.simple_select_list_txn(
                 txn,
                 table="public_room_list_stream",
                 keyvalues={
@@ -515,7 +515,7 @@ class RoomStore(RoomBackgroundUpdateStore, RoomWorkerStore, SearchStore):
                 add_to_stream = bool(entries[-1]["visibility"]) != is_public
 
             if add_to_stream:
-                self.simple_insert_txn(
+                self.db.simple_insert_txn(
                     txn,
                     table="public_room_list_stream",
                     values={
@@ -528,7 +528,7 @@ class RoomStore(RoomBackgroundUpdateStore, RoomWorkerStore, SearchStore):
                 )
 
         with self._public_room_id_gen.get_next() as next_id:
-            yield self.runInteraction(
+            yield self.db.runInteraction(
                 "set_room_is_public", set_room_is_public_txn, next_id
             )
         self.hs.get_notifier().on_new_replication_data()
@@ -555,7 +555,7 @@ class RoomStore(RoomBackgroundUpdateStore, RoomWorkerStore, SearchStore):
         def set_room_is_public_appservice_txn(txn, next_id):
             if is_public:
                 try:
-                    self.simple_insert_txn(
+                    self.db.simple_insert_txn(
                         txn,
                         table="appservice_room_list",
                         values={
@@ -568,7 +568,7 @@ class RoomStore(RoomBackgroundUpdateStore, RoomWorkerStore, SearchStore):
                     # We've already inserted, nothing to do.
                     return
             else:
-                self.simple_delete_txn(
+                self.db.simple_delete_txn(
                     txn,
                     table="appservice_room_list",
                     keyvalues={
@@ -578,7 +578,7 @@ class RoomStore(RoomBackgroundUpdateStore, RoomWorkerStore, SearchStore):
                     },
                 )
 
-            entries = self.simple_select_list_txn(
+            entries = self.db.simple_select_list_txn(
                 txn,
                 table="public_room_list_stream",
                 keyvalues={
@@ -596,7 +596,7 @@ class RoomStore(RoomBackgroundUpdateStore, RoomWorkerStore, SearchStore):
                 add_to_stream = bool(entries[-1]["visibility"]) != is_public
 
             if add_to_stream:
-                self.simple_insert_txn(
+                self.db.simple_insert_txn(
                     txn,
                     table="public_room_list_stream",
                     values={
@@ -609,7 +609,7 @@ class RoomStore(RoomBackgroundUpdateStore, RoomWorkerStore, SearchStore):
                 )
 
         with self._public_room_id_gen.get_next() as next_id:
-            yield self.runInteraction(
+            yield self.db.runInteraction(
                 "set_room_is_public_appservice",
                 set_room_is_public_appservice_txn,
                 next_id,
@@ -626,7 +626,7 @@ class RoomStore(RoomBackgroundUpdateStore, RoomWorkerStore, SearchStore):
             row = txn.fetchone()
             return row[0] or 0
 
-        return self.runInteraction("get_rooms", f)
+        return self.db.runInteraction("get_rooms", f)
 
     def _store_room_topic_txn(self, txn, event):
         if hasattr(event, "content") and "topic" in event.content:
@@ -660,7 +660,7 @@ class RoomStore(RoomBackgroundUpdateStore, RoomWorkerStore, SearchStore):
                # Ignore the event if one of the values isn't an integer.
                 return
 
-            self.simple_insert_txn(
+            self.db.simple_insert_txn(
                 txn=txn,
                 table="room_retention",
                 values={
@@ -679,7 +679,7 @@ class RoomStore(RoomBackgroundUpdateStore, RoomWorkerStore, SearchStore):
         self, room_id, event_id, user_id, reason, content, received_ts
     ):
         next_id = self._event_reports_id_gen.get_next()
-        return self.simple_insert(
+        return self.db.simple_insert(
             table="event_reports",
             values={
                 "id": next_id,
@@ -712,7 +712,9 @@ class RoomStore(RoomBackgroundUpdateStore, RoomWorkerStore, SearchStore):
         if prev_id == current_id:
             return defer.succeed([])
 
-        return self.runInteraction("get_all_new_public_rooms", get_all_new_public_rooms)
+        return self.db.runInteraction(
+            "get_all_new_public_rooms", get_all_new_public_rooms
+        )
 
     @defer.inlineCallbacks
     def block_room(self, room_id, user_id):
@@ -725,14 +727,14 @@ class RoomStore(RoomBackgroundUpdateStore, RoomWorkerStore, SearchStore):
         Returns:
             Deferred
         """
-        yield self.simple_upsert(
+        yield self.db.simple_upsert(
             table="blocked_rooms",
             keyvalues={"room_id": room_id},
             values={},
             insertion_values={"user_id": user_id},
             desc="block_room",
         )
-        yield self.runInteraction(
+        yield self.db.runInteraction(
             "block_room_invalidation",
             self._invalidate_cache_and_stream,
             self.is_room_blocked,
@@ -763,7 +765,9 @@ class RoomStore(RoomBackgroundUpdateStore, RoomWorkerStore, SearchStore):
 
             return local_media_mxcs, remote_media_mxcs
 
-        return self.runInteraction("get_media_ids_in_room", _get_media_mxcs_in_room_txn)
+        return self.db.runInteraction(
+            "get_media_ids_in_room", _get_media_mxcs_in_room_txn
+        )
 
     def quarantine_media_ids_in_room(self, room_id, quarantined_by):
         """For a room loops through all events with media and quarantines
@@ -802,7 +806,7 @@ class RoomStore(RoomBackgroundUpdateStore, RoomWorkerStore, SearchStore):
 
             return total_media_quarantined
 
-        return self.runInteraction(
+        return self.db.runInteraction(
             "quarantine_media_in_room", _quarantine_media_in_room_txn
         )
 
@@ -907,7 +911,7 @@ class RoomStore(RoomBackgroundUpdateStore, RoomWorkerStore, SearchStore):
 
             txn.execute(sql, args)
 
-            rows = self.cursor_to_dict(txn)
+            rows = self.db.cursor_to_dict(txn)
             rooms_dict = {}
 
             for row in rows:
@@ -923,7 +927,7 @@ class RoomStore(RoomBackgroundUpdateStore, RoomWorkerStore, SearchStore):
 
                 txn.execute(sql)
 
-                rows = self.cursor_to_dict(txn)
+                rows = self.db.cursor_to_dict(txn)
 
                 # If a room isn't already in the dict (i.e. it doesn't have a retention
                 # policy in its state), add it with a null policy.
@@ -936,7 +940,7 @@ class RoomStore(RoomBackgroundUpdateStore, RoomWorkerStore, SearchStore):
 
             return rooms_dict
 
-        rooms = yield self.runInteraction(
+        rooms = yield self.db.runInteraction(
             "get_rooms_for_retention_period_in_range",
             get_rooms_for_retention_period_in_range_txn,
         )
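
A recurring idiom in the room.py hunks above: each store method defines an inner *_txn function that receives an open transaction, then hands it to self.db.runInteraction together with a description string (used for logging and metrics) and any extra positional arguments, such as the next_id drawn from a stream ID generator. A hedged, self-contained illustration of the idiom; the method, table and column names are invented for the example:

    from twisted.internet import defer

    # Illustrative only: mirrors the inner-function transaction idiom,
    # including threading next_id from the ID generator into the txn.
    @defer.inlineCallbacks
    def set_room_flag(self, room_id, flag):
        def _set_room_flag_txn(txn, next_id):
            txn.execute(
                "UPDATE rooms SET flag = ? WHERE room_id = ?", (flag, room_id)
            )
        with self._stream_id_gen.get_next() as next_id:
            yield self.db.runInteraction("set_room_flag", _set_room_flag_txn, next_id)
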
diff --git a/synapse/storage/data_stores/main/roommember.py b/synapse/storage/data_stores/main/roommember.py
index fe2428a281..92e3b9c512 100644
--- a/synapse/storage/data_stores/main/roommember.py
+++ b/synapse/storage/data_stores/main/roommember.py
@@ -26,9 +26,13 @@ from twisted.internet import defer
 from synapse.api.constants import EventTypes, Membership
 from synapse.metrics import LaterGauge
 from synapse.metrics.background_process_metrics import run_as_background_process
-from synapse.storage._base import LoggingTransaction, make_in_list_sql_clause
-from synapse.storage.background_updates import BackgroundUpdateStore
+from synapse.storage._base import (
+    LoggingTransaction,
+    SQLBaseStore,
+    make_in_list_sql_clause,
+)
 from synapse.storage.data_stores.main.events_worker import EventsWorkerStore
+from synapse.storage.database import Database
 from synapse.storage.engines import Sqlite3Engine
 from synapse.storage.roommember import (
     GetRoomsForUserWithStreamOrdering,
@@ -51,8 +55,8 @@ _CURRENT_STATE_MEMBERSHIP_UPDATE_NAME = "current_state_events_membership"
 
 
 class RoomMemberWorkerStore(EventsWorkerStore):
-    def __init__(self, db_conn, hs):
-        super(RoomMemberWorkerStore, self).__init__(db_conn, hs)
+    def __init__(self, database: Database, db_conn, hs):
+        super(RoomMemberWorkerStore, self).__init__(database, db_conn, hs)
 
         # Is the current_state_events.membership up to date? Or is the
         # background update still running?
@@ -116,7 +120,7 @@ class RoomMemberWorkerStore(EventsWorkerStore):
             txn.execute(query)
             return list(txn)[0][0]
 
-        count = yield self.runInteraction("get_known_servers", _transact)
+        count = yield self.db.runInteraction("get_known_servers", _transact)
 
         # We always know about ourselves, even if we have nothing in
         # room_memberships (for example, the server is new).
@@ -128,7 +132,7 @@ class RoomMemberWorkerStore(EventsWorkerStore):
         membership column is up to date
         """
 
-        pending_update = self.simple_select_one_txn(
+        pending_update = self.db.simple_select_one_txn(
             txn,
             table="background_updates",
             keyvalues={"update_name": _CURRENT_STATE_MEMBERSHIP_UPDATE_NAME},
@@ -144,7 +148,7 @@ class RoomMemberWorkerStore(EventsWorkerStore):
                 15.0,
                 run_as_background_process,
                 "_check_safe_current_state_events_membership_updated",
-                self.runInteraction,
+                self.db.runInteraction,
                 "_check_safe_current_state_events_membership_updated",
                 self._check_safe_current_state_events_membership_updated_txn,
             )
@@ -161,7 +165,7 @@ class RoomMemberWorkerStore(EventsWorkerStore):
 
     @cached(max_entries=100000, iterable=True)
     def get_users_in_room(self, room_id):
-        return self.runInteraction(
+        return self.db.runInteraction(
             "get_users_in_room", self.get_users_in_room_txn, room_id
         )
 
@@ -269,7 +273,7 @@ class RoomMemberWorkerStore(EventsWorkerStore):
 
             return res
 
-        return self.runInteraction("get_room_summary", _get_room_summary_txn)
+        return self.db.runInteraction("get_room_summary", _get_room_summary_txn)
 
     def _get_user_counts_in_room_txn(self, txn, room_id):
         """
@@ -339,7 +343,7 @@ class RoomMemberWorkerStore(EventsWorkerStore):
         if not membership_list:
             return defer.succeed(None)
 
-        rooms = yield self.runInteraction(
+        rooms = yield self.db.runInteraction(
             "get_rooms_for_user_where_membership_is",
             self._get_rooms_for_user_where_membership_is_txn,
             user_id,
@@ -392,7 +396,7 @@ class RoomMemberWorkerStore(EventsWorkerStore):
                 )
 
             txn.execute(sql, (user_id, *args))
-            results = [RoomsForUser(**r) for r in self.cursor_to_dict(txn)]
+            results = [RoomsForUser(**r) for r in self.db.cursor_to_dict(txn)]
 
         if do_invite:
             sql = (
@@ -412,7 +416,7 @@ class RoomMemberWorkerStore(EventsWorkerStore):
                     stream_ordering=r["stream_ordering"],
                     membership=Membership.INVITE,
                 )
-                for r in self.cursor_to_dict(txn)
+                for r in self.db.cursor_to_dict(txn)
             )
 
         return results
@@ -603,7 +607,7 @@ class RoomMemberWorkerStore(EventsWorkerStore):
             to `user_id` and ProfileInfo (or None if not a join event).
         """
 
-        rows = yield self.simple_select_many_batch(
+        rows = yield self.db.simple_select_many_batch(
             table="room_memberships",
             column="event_id",
             iterable=event_ids,
@@ -643,7 +647,7 @@ class RoomMemberWorkerStore(EventsWorkerStore):
         # the returned user actually has the correct domain.
         like_clause = "%:" + host
 
-        rows = yield self.execute("is_host_joined", None, sql, room_id, like_clause)
+        rows = yield self.db.execute("is_host_joined", None, sql, room_id, like_clause)
 
         if not rows:
             return False
@@ -683,7 +687,7 @@ class RoomMemberWorkerStore(EventsWorkerStore):
         # the returned user actually has the correct domain.
         like_clause = "%:" + host
 
-        rows = yield self.execute("was_host_joined", None, sql, room_id, like_clause)
+        rows = yield self.db.execute("was_host_joined", None, sql, room_id, like_clause)
 
         if not rows:
             return False
@@ -753,7 +757,7 @@ class RoomMemberWorkerStore(EventsWorkerStore):
             rows = txn.fetchall()
             return rows[0][0]
 
-        count = yield self.runInteraction("did_forget_membership", f)
+        count = yield self.db.runInteraction("did_forget_membership", f)
         return count == 0
 
     @cached()
@@ -790,7 +794,7 @@ class RoomMemberWorkerStore(EventsWorkerStore):
             txn.execute(sql, (user_id,))
             return set(row[0] for row in txn if row[1] == 0)
 
-        return self.runInteraction(
+        return self.db.runInteraction(
             "get_forgotten_rooms_for_user", _get_forgotten_rooms_for_user_txn
         )
 
@@ -805,7 +809,7 @@ class RoomMemberWorkerStore(EventsWorkerStore):
             Deferred[set[str]]: Set of room IDs.
         """
 
-        room_ids = yield self.simple_select_onecol(
+        room_ids = yield self.db.simple_select_onecol(
             table="room_memberships",
             keyvalues={"membership": Membership.JOIN, "user_id": user_id},
             retcol="room_id",
@@ -820,7 +824,7 @@ class RoomMemberWorkerStore(EventsWorkerStore):
         """Get user_id and membership of a set of event IDs.
         """
 
-        return self.simple_select_many_batch(
+        return self.db.simple_select_many_batch(
             table="room_memberships",
             column="event_id",
             iterable=member_event_ids,
@@ -831,17 +835,17 @@ class RoomMemberWorkerStore(EventsWorkerStore):
         )
 
 
-class RoomMemberBackgroundUpdateStore(BackgroundUpdateStore):
-    def __init__(self, db_conn, hs):
-        super(RoomMemberBackgroundUpdateStore, self).__init__(db_conn, hs)
-        self.register_background_update_handler(
+class RoomMemberBackgroundUpdateStore(SQLBaseStore):
+    def __init__(self, database: Database, db_conn, hs):
+        super(RoomMemberBackgroundUpdateStore, self).__init__(database, db_conn, hs)
+        self.db.updates.register_background_update_handler(
             _MEMBERSHIP_PROFILE_UPDATE_NAME, self._background_add_membership_profile
         )
-        self.register_background_update_handler(
+        self.db.updates.register_background_update_handler(
             _CURRENT_STATE_MEMBERSHIP_UPDATE_NAME,
             self._background_current_state_membership,
         )
-        self.register_background_index_update(
+        self.db.updates.register_background_index_update(
             "room_membership_forgotten_idx",
             index_name="room_memberships_user_room_forgotten",
             table="room_memberships",
@@ -874,7 +878,7 @@ class RoomMemberBackgroundUpdateStore(BackgroundUpdateStore):
 
             txn.execute(sql, (target_min_stream_id, max_stream_id, batch_size))
 
-            rows = self.cursor_to_dict(txn)
+            rows = self.db.cursor_to_dict(txn)
             if not rows:
                 return 0
 
@@ -909,18 +913,20 @@ class RoomMemberBackgroundUpdateStore(BackgroundUpdateStore):
                 "max_stream_id_exclusive": min_stream_id,
             }
 
-            self._background_update_progress_txn(
+            self.db.updates._background_update_progress_txn(
                 txn, _MEMBERSHIP_PROFILE_UPDATE_NAME, progress
             )
 
             return len(rows)
 
-        result = yield self.runInteraction(
+        result = yield self.db.runInteraction(
             _MEMBERSHIP_PROFILE_UPDATE_NAME, add_membership_profile_txn
         )
 
         if not result:
-            yield self._end_background_update(_MEMBERSHIP_PROFILE_UPDATE_NAME)
+            yield self.db.updates._end_background_update(
+                _MEMBERSHIP_PROFILE_UPDATE_NAME
+            )
 
         return result
 
@@ -959,7 +965,7 @@ class RoomMemberBackgroundUpdateStore(BackgroundUpdateStore):
 
                 last_processed_room = next_room
 
-            self._background_update_progress_txn(
+            self.db.updates._background_update_progress_txn(
                 txn,
                 _CURRENT_STATE_MEMBERSHIP_UPDATE_NAME,
                 {"last_processed_room": last_processed_room},
@@ -971,26 +977,28 @@ class RoomMemberBackgroundUpdateStore(BackgroundUpdateStore):
         # string, which will compare before all room IDs correctly.
         last_processed_room = progress.get("last_processed_room", "")
 
-        row_count, finished = yield self.runInteraction(
+        row_count, finished = yield self.db.runInteraction(
             "_background_current_state_membership_update",
             _background_current_state_membership_txn,
             last_processed_room,
         )
 
         if finished:
-            yield self._end_background_update(_CURRENT_STATE_MEMBERSHIP_UPDATE_NAME)
+            yield self.db.updates._end_background_update(
+                _CURRENT_STATE_MEMBERSHIP_UPDATE_NAME
+            )
 
         return row_count
 
 
 class RoomMemberStore(RoomMemberWorkerStore, RoomMemberBackgroundUpdateStore):
-    def __init__(self, db_conn, hs):
-        super(RoomMemberStore, self).__init__(db_conn, hs)
+    def __init__(self, database: Database, db_conn, hs):
+        super(RoomMemberStore, self).__init__(database, db_conn, hs)
 
     def _store_room_members_txn(self, txn, events, backfilled):
         """Store a room member in the database.
         """
-        self.simple_insert_many_txn(
+        self.db.simple_insert_many_txn(
             txn,
             table="room_memberships",
             values=[
@@ -1028,7 +1036,7 @@ class RoomMemberStore(RoomMemberWorkerStore, RoomMemberBackgroundUpdateStore):
             is_mine = self.hs.is_mine_id(event.state_key)
             if is_new_state and is_mine:
                 if event.membership == Membership.INVITE:
-                    self.simple_insert_txn(
+                    self.db.simple_insert_txn(
                         txn,
                         table="local_invites",
                         values={
@@ -1068,7 +1076,7 @@ class RoomMemberStore(RoomMemberWorkerStore, RoomMemberBackgroundUpdateStore):
             txn.execute(sql, (stream_ordering, True, room_id, user_id))
 
         with self._stream_id_gen.get_next() as stream_ordering:
-            yield self.runInteraction("locally_reject_invite", f, stream_ordering)
+            yield self.db.runInteraction("locally_reject_invite", f, stream_ordering)
 
     def forget(self, user_id, room_id):
         """Indicate that user_id wishes to discard history for room_id."""
@@ -1091,7 +1099,7 @@ class RoomMemberStore(RoomMemberWorkerStore, RoomMemberBackgroundUpdateStore):
                 txn, self.get_forgotten_rooms_for_user, (user_id,)
             )
 
-        return self.runInteraction("forget_membership", f)
+        return self.db.runInteraction("forget_membership", f)
 
 
 class _JoinedHostsCache(object):
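
Note that the *BackgroundUpdateStore classes here stop subclassing BackgroundUpdateStore altogether; the background-update plumbing (register_background_update_handler, register_background_index_update, _background_update_progress_txn, _end_background_update) now hangs off self.db.updates. A hedged sketch of the object this implies; the class name and internals are assumptions, only the method names appear in the diff:

    # Hedged sketch of the updater reachable as self.db.updates.
    class BackgroundUpdater(object):
        def __init__(self, db):
            self.db = db
            self._handlers = {}

        def register_background_update_handler(self, update_name, handler):
            # handler(progress_dict, batch_size) -> items processed; handlers
            # call _end_background_update themselves once they finish.
            self._handlers[update_name] = handler

        def register_noop_background_update(self, update_name):
            # For updates that may still be queued on old deployments but no
            # longer need to do any work (the real helper would presumably
            # also mark the update as finished).
            def _noop(progress, batch_size):
                return 1
            self._handlers[update_name] = _noop
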
diff --git a/synapse/storage/data_stores/main/schema/delta/56/redaction_censor.sql b/synapse/storage/data_stores/main/schema/delta/56/redaction_censor.sql
index fe51b02309..ea95db0ed7 100644
--- a/synapse/storage/data_stores/main/schema/delta/56/redaction_censor.sql
+++ b/synapse/storage/data_stores/main/schema/delta/56/redaction_censor.sql
@@ -14,4 +14,3 @@
  */
 
 ALTER TABLE redactions ADD COLUMN have_censored BOOL NOT NULL DEFAULT false;
-CREATE INDEX redactions_have_censored ON redactions(event_id) WHERE not have_censored;
diff --git a/synapse/storage/data_stores/main/schema/delta/56/redaction_censor2.sql b/synapse/storage/data_stores/main/schema/delta/56/redaction_censor2.sql
index 77a5eca499..49ce35d794 100644
--- a/synapse/storage/data_stores/main/schema/delta/56/redaction_censor2.sql
+++ b/synapse/storage/data_stores/main/schema/delta/56/redaction_censor2.sql
@@ -14,7 +14,9 @@
  */
 
 ALTER TABLE redactions ADD COLUMN received_ts BIGINT;
-CREATE INDEX redactions_have_censored_ts ON redactions(received_ts) WHERE not have_censored;
 
 INSERT INTO background_updates (update_name, progress_json) VALUES
   ('redactions_received_ts', '{}');
+
+INSERT INTO background_updates (update_name, progress_json) VALUES
+  ('redactions_have_censored_ts_idx', '{}');
diff --git a/synapse/storage/data_stores/main/schema/delta/56/redaction_censor4.sql b/synapse/storage/data_stores/main/schema/delta/56/redaction_censor4.sql
new file mode 100644
index 0000000000..b7550f6f4e
--- /dev/null
+++ b/synapse/storage/data_stores/main/schema/delta/56/redaction_censor4.sql
@@ -0,0 +1,16 @@
+/* Copyright 2019 The Matrix.org Foundation C.I.C.
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+DROP INDEX IF EXISTS redactions_have_censored;
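
The three redaction_censor deltas read as a unit: the inline partial-index statements (CREATE INDEX ... WHERE not have_censored) are removed from the earlier deltas, a redactions_have_censored_ts_idx row is queued in background_updates instead, and the new delta drops any index the old inline statement may already have created. The matching Python-side registration is not shown in this section, but it presumably takes the register_background_index_update form seen above, roughly:

    # Assumed pairing with the 'redactions_have_censored_ts_idx' row that
    # redaction_censor2.sql inserts; the arguments are inferred, not shown.
    self.db.updates.register_background_index_update(
        "redactions_have_censored_ts_idx",
        index_name="redactions_have_censored_ts",
        table="redactions",
        columns=["received_ts"],
        where_clause="NOT have_censored",
    )
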
diff --git a/synapse/storage/data_stores/main/search.py b/synapse/storage/data_stores/main/search.py
index f735cf095c..4eec2fae5e 100644
--- a/synapse/storage/data_stores/main/search.py
+++ b/synapse/storage/data_stores/main/search.py
@@ -24,8 +24,8 @@ from canonicaljson import json
 from twisted.internet import defer
 
 from synapse.api.errors import SynapseError
-from synapse.storage._base import make_in_list_sql_clause
-from synapse.storage.background_updates import BackgroundUpdateStore
+from synapse.storage._base import SQLBaseStore, make_in_list_sql_clause
+from synapse.storage.database import Database
 from synapse.storage.engines import PostgresEngine, Sqlite3Engine
 
 logger = logging.getLogger(__name__)
@@ -36,23 +36,23 @@ SearchEntry = namedtuple(
 )
 
 
-class SearchBackgroundUpdateStore(BackgroundUpdateStore):
+class SearchBackgroundUpdateStore(SQLBaseStore):
 
     EVENT_SEARCH_UPDATE_NAME = "event_search"
     EVENT_SEARCH_ORDER_UPDATE_NAME = "event_search_order"
     EVENT_SEARCH_USE_GIST_POSTGRES_NAME = "event_search_postgres_gist"
     EVENT_SEARCH_USE_GIN_POSTGRES_NAME = "event_search_postgres_gin"
 
-    def __init__(self, db_conn, hs):
-        super(SearchBackgroundUpdateStore, self).__init__(db_conn, hs)
+    def __init__(self, database: Database, db_conn, hs):
+        super(SearchBackgroundUpdateStore, self).__init__(database, db_conn, hs)
 
         if not hs.config.enable_search:
             return
 
-        self.register_background_update_handler(
+        self.db.updates.register_background_update_handler(
             self.EVENT_SEARCH_UPDATE_NAME, self._background_reindex_search
         )
-        self.register_background_update_handler(
+        self.db.updates.register_background_update_handler(
             self.EVENT_SEARCH_ORDER_UPDATE_NAME, self._background_reindex_search_order
         )
 
@@ -61,9 +61,11 @@ class SearchBackgroundUpdateStore(BackgroundUpdateStore):
         # a GIN index. However, it's possible that some people might still have
         # the background update queued, so we register a handler to clear the
         # background update.
-        self.register_noop_background_update(self.EVENT_SEARCH_USE_GIST_POSTGRES_NAME)
+        self.db.updates.register_noop_background_update(
+            self.EVENT_SEARCH_USE_GIST_POSTGRES_NAME
+        )
 
-        self.register_background_update_handler(
+        self.db.updates.register_background_update_handler(
             self.EVENT_SEARCH_USE_GIN_POSTGRES_NAME, self._background_reindex_gin_search
         )
 
@@ -93,7 +95,7 @@ class SearchBackgroundUpdateStore(BackgroundUpdateStore):
             # store_search_entries_txn with a generator function, but that
             # would mean having two cursors open on the database at once.
             # Instead we just build a list of results.
-            rows = self.cursor_to_dict(txn)
+            rows = self.db.cursor_to_dict(txn)
             if not rows:
                 return 0
 
@@ -153,18 +155,18 @@ class SearchBackgroundUpdateStore(BackgroundUpdateStore):
                 "rows_inserted": rows_inserted + len(event_search_rows),
             }
 
-            self._background_update_progress_txn(
+            self.db.updates._background_update_progress_txn(
                 txn, self.EVENT_SEARCH_UPDATE_NAME, progress
             )
 
             return len(event_search_rows)
 
-        result = yield self.runInteraction(
+        result = yield self.db.runInteraction(
             self.EVENT_SEARCH_UPDATE_NAME, reindex_search_txn
         )
 
         if not result:
-            yield self._end_background_update(self.EVENT_SEARCH_UPDATE_NAME)
+            yield self.db.updates._end_background_update(self.EVENT_SEARCH_UPDATE_NAME)
 
         return result
 
@@ -206,9 +208,11 @@ class SearchBackgroundUpdateStore(BackgroundUpdateStore):
                 conn.set_session(autocommit=False)
 
         if isinstance(self.database_engine, PostgresEngine):
-            yield self.runWithConnection(create_index)
+            yield self.db.runWithConnection(create_index)
 
-        yield self._end_background_update(self.EVENT_SEARCH_USE_GIN_POSTGRES_NAME)
+        yield self.db.updates._end_background_update(
+            self.EVENT_SEARCH_USE_GIN_POSTGRES_NAME
+        )
         return 1
 
     @defer.inlineCallbacks
@@ -237,14 +241,14 @@ class SearchBackgroundUpdateStore(BackgroundUpdateStore):
                 )
                 conn.set_session(autocommit=False)
 
-            yield self.runWithConnection(create_index)
+            yield self.db.runWithConnection(create_index)
 
             pg = dict(progress)
             pg["have_added_indexes"] = True
 
-            yield self.runInteraction(
+            yield self.db.runInteraction(
                 self.EVENT_SEARCH_ORDER_UPDATE_NAME,
-                self._background_update_progress_txn,
+                self.db.updates._background_update_progress_txn,
                 self.EVENT_SEARCH_ORDER_UPDATE_NAME,
                 pg,
             )
@@ -274,18 +278,20 @@ class SearchBackgroundUpdateStore(BackgroundUpdateStore):
                 "have_added_indexes": True,
             }
 
-            self._background_update_progress_txn(
+            self.db.updates._background_update_progress_txn(
                 txn, self.EVENT_SEARCH_ORDER_UPDATE_NAME, progress
             )
 
             return len(rows), True
 
-        num_rows, finished = yield self.runInteraction(
+        num_rows, finished = yield self.db.runInteraction(
             self.EVENT_SEARCH_ORDER_UPDATE_NAME, reindex_search_txn
         )
 
         if not finished:
-            yield self._end_background_update(self.EVENT_SEARCH_ORDER_UPDATE_NAME)
+            yield self.db.updates._end_background_update(
+                self.EVENT_SEARCH_ORDER_UPDATE_NAME
+            )
 
         return num_rows
 
@@ -337,8 +343,8 @@ class SearchBackgroundUpdateStore(BackgroundUpdateStore):
 
 
 class SearchStore(SearchBackgroundUpdateStore):
-    def __init__(self, db_conn, hs):
-        super(SearchStore, self).__init__(db_conn, hs)
+    def __init__(self, database: Database, db_conn, hs):
+        super(SearchStore, self).__init__(database, db_conn, hs)
 
     def store_event_search_txn(self, txn, event, key, value):
         """Add event to the search table
@@ -441,7 +447,9 @@ class SearchStore(SearchBackgroundUpdateStore):
         # entire table from the database.
         sql += " ORDER BY rank DESC LIMIT 500"
 
-        results = yield self.execute("search_msgs", self.cursor_to_dict, sql, *args)
+        results = yield self.db.execute(
+            "search_msgs", self.db.cursor_to_dict, sql, *args
+        )
 
         results = list(filter(lambda row: row["room_id"] in room_ids, results))
 
@@ -455,8 +463,8 @@ class SearchStore(SearchBackgroundUpdateStore):
 
         count_sql += " GROUP BY room_id"
 
-        count_results = yield self.execute(
-            "search_rooms_count", self.cursor_to_dict, count_sql, *count_args
+        count_results = yield self.db.execute(
+            "search_rooms_count", self.db.cursor_to_dict, count_sql, *count_args
         )
 
         count = sum(row["count"] for row in count_results if row["room_id"] in room_ids)
@@ -586,7 +594,9 @@ class SearchStore(SearchBackgroundUpdateStore):
 
         args.append(limit)
 
-        results = yield self.execute("search_rooms", self.cursor_to_dict, sql, *args)
+        results = yield self.db.execute(
+            "search_rooms", self.db.cursor_to_dict, sql, *args
+        )
 
         results = list(filter(lambda row: row["room_id"] in room_ids, results))
 
@@ -600,8 +610,8 @@ class SearchStore(SearchBackgroundUpdateStore):
 
         count_sql += " GROUP BY room_id"
 
-        count_results = yield self.execute(
-            "search_rooms_count", self.cursor_to_dict, count_sql, *count_args
+        count_results = yield self.db.execute(
+            "search_rooms_count", self.db.cursor_to_dict, count_sql, *count_args
         )
 
         count = sum(row["count"] for row in count_results if row["room_id"] in room_ids)
@@ -686,7 +696,7 @@ class SearchStore(SearchBackgroundUpdateStore):
 
             return highlight_words
 
-        return self.runInteraction("_find_highlights", f)
+        return self.db.runInteraction("_find_highlights", f)
 
 
 def _to_postgres_options(options_dict):
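
Two helper spellings recur in the search.py hunks: db.runInteraction(desc, fn, *args) for work done inside a transaction, and db.execute(desc, decoder, sql, *args), where the decoder (here self.db.cursor_to_dict) post-processes the raw cursor. cursor_to_dict itself is a small adapter; a hedged sketch consistent with how it is consumed in these hunks:

    # Sketch: turn a DB-API cursor's rows into dicts keyed by column
    # name, as the search queries above expect.
    def cursor_to_dict(cursor):
        col_headers = [d[0] for d in cursor.description]
        return [dict(zip(col_headers, row)) for row in cursor]
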
diff --git a/synapse/storage/data_stores/main/signatures.py b/synapse/storage/data_stores/main/signatures.py
index f3da29ce14..563216b63c 100644
--- a/synapse/storage/data_stores/main/signatures.py
+++ b/synapse/storage/data_stores/main/signatures.py
@@ -48,7 +48,7 @@ class SignatureWorkerStore(SQLBaseStore):
                 for event_id in event_ids
             }
 
-        return self.runInteraction("get_event_reference_hashes", f)
+        return self.db.runInteraction("get_event_reference_hashes", f)
 
     @defer.inlineCallbacks
     def add_event_hashes(self, event_ids):
@@ -98,4 +98,4 @@ class SignatureStore(SignatureWorkerStore):
                 }
             )
 
-        self.simple_insert_many_txn(txn, table="event_reference_hashes", values=vals)
+        self.db.simple_insert_many_txn(txn, table="event_reference_hashes", values=vals)
diff --git a/synapse/storage/data_stores/main/state.py b/synapse/storage/data_stores/main/state.py
index 2b33ec1a35..9ef7b48c74 100644
--- a/synapse/storage/data_stores/main/state.py
+++ b/synapse/storage/data_stores/main/state.py
@@ -27,8 +27,8 @@ from synapse.api.errors import NotFoundError
 from synapse.events import EventBase
 from synapse.events.snapshot import EventContext
 from synapse.storage._base import SQLBaseStore
-from synapse.storage.background_updates import BackgroundUpdateStore
 from synapse.storage.data_stores.main.events_worker import EventsWorkerStore
+from synapse.storage.database import Database
 from synapse.storage.engines import PostgresEngine
 from synapse.storage.state import StateFilter
 from synapse.util.caches import get_cache_factor_for, intern_string
@@ -89,7 +89,7 @@ class StateGroupBackgroundUpdateStore(SQLBaseStore):
             count = 0
 
             while next_group:
-                next_group = self.simple_select_one_onecol_txn(
+                next_group = self.db.simple_select_one_onecol_txn(
                     txn,
                     table="state_group_edges",
                     keyvalues={"state_group": next_group},
@@ -192,7 +192,7 @@ class StateGroupBackgroundUpdateStore(SQLBaseStore):
                     ):
                         break
 
-                    next_group = self.simple_select_one_onecol_txn(
+                    next_group = self.db.simple_select_one_onecol_txn(
                         txn,
                         table="state_group_edges",
                         keyvalues={"state_group": next_group},
@@ -214,8 +214,8 @@ class StateGroupWorkerStore(
     STATE_GROUP_INDEX_UPDATE_NAME = "state_group_state_type_index"
     CURRENT_STATE_INDEX_UPDATE_NAME = "current_state_members_idx"
 
-    def __init__(self, db_conn, hs):
-        super(StateGroupWorkerStore, self).__init__(db_conn, hs)
+    def __init__(self, database: Database, db_conn, hs):
+        super(StateGroupWorkerStore, self).__init__(database, db_conn, hs)
 
         # Originally the state store used a single DictionaryCache to cache the
         # event IDs for the state types in a given state group to avoid hammering
@@ -348,7 +348,9 @@ class StateGroupWorkerStore(
                 (intern_string(r[0]), intern_string(r[1])): to_ascii(r[2]) for r in txn
             }
 
-        return self.runInteraction("get_current_state_ids", _get_current_state_ids_txn)
+        return self.db.runInteraction(
+            "get_current_state_ids", _get_current_state_ids_txn
+        )
 
     # FIXME: how should this be cached?
     def get_filtered_current_state_ids(self, room_id, state_filter=StateFilter.all()):
@@ -392,7 +394,7 @@ class StateGroupWorkerStore(
 
             return results
 
-        return self.runInteraction(
+        return self.db.runInteraction(
             "get_filtered_current_state_ids", _get_filtered_current_state_ids_txn
         )
 
@@ -431,7 +433,7 @@ class StateGroupWorkerStore(
         """
 
         def _get_state_group_delta_txn(txn):
-            prev_group = self.simple_select_one_onecol_txn(
+            prev_group = self.db.simple_select_one_onecol_txn(
                 txn,
                 table="state_group_edges",
                 keyvalues={"state_group": state_group},
@@ -442,7 +444,7 @@ class StateGroupWorkerStore(
             if not prev_group:
                 return _GetStateGroupDelta(None, None)
 
-            delta_ids = self.simple_select_list_txn(
+            delta_ids = self.db.simple_select_list_txn(
                 txn,
                 table="state_groups_state",
                 keyvalues={"state_group": state_group},
@@ -454,7 +456,9 @@ class StateGroupWorkerStore(
                 {(row["type"], row["state_key"]): row["event_id"] for row in delta_ids},
             )
 
-        return self.runInteraction("get_state_group_delta", _get_state_group_delta_txn)
+        return self.db.runInteraction(
+            "get_state_group_delta", _get_state_group_delta_txn
+        )
 
     @defer.inlineCallbacks
     def get_state_groups_ids(self, _room_id, event_ids):
@@ -540,7 +544,7 @@ class StateGroupWorkerStore(
 
         chunks = [groups[i : i + 100] for i in range(0, len(groups), 100)]
         for chunk in chunks:
-            res = yield self.runInteraction(
+            res = yield self.db.runInteraction(
                 "_get_state_groups_from_groups",
                 self._get_state_groups_from_groups_txn,
                 chunk,
@@ -644,7 +648,7 @@ class StateGroupWorkerStore(
 
     @cached(max_entries=50000)
     def _get_state_group_for_event(self, event_id):
-        return self.simple_select_one_onecol(
+        return self.db.simple_select_one_onecol(
             table="event_to_state_groups",
             keyvalues={"event_id": event_id},
             retcol="state_group",
@@ -661,7 +665,7 @@ class StateGroupWorkerStore(
     def _get_state_group_for_events(self, event_ids):
         """Returns mapping event_id -> state_group
         """
-        rows = yield self.simple_select_many_batch(
+        rows = yield self.db.simple_select_many_batch(
             table="event_to_state_groups",
             column="event_id",
             iterable=event_ids,
@@ -902,7 +906,7 @@ class StateGroupWorkerStore(
 
             state_group = self.database_engine.get_next_state_group_id(txn)
 
-            self.simple_insert_txn(
+            self.db.simple_insert_txn(
                 txn,
                 table="state_groups",
                 values={"id": state_group, "room_id": room_id, "event_id": event_id},
@@ -911,7 +915,7 @@ class StateGroupWorkerStore(
             # We persist as a delta if we can, while also ensuring the chain
             # of deltas isn't too long, as otherwise read performance degrades.
             if prev_group:
-                is_in_db = self.simple_select_one_onecol_txn(
+                is_in_db = self.db.simple_select_one_onecol_txn(
                     txn,
                     table="state_groups",
                     keyvalues={"id": prev_group},
@@ -926,13 +930,13 @@ class StateGroupWorkerStore(
 
                 potential_hops = self._count_state_group_hops_txn(txn, prev_group)
             if prev_group and potential_hops < MAX_STATE_DELTA_HOPS:
-                self.simple_insert_txn(
+                self.db.simple_insert_txn(
                     txn,
                     table="state_group_edges",
                     values={"state_group": state_group, "prev_state_group": prev_group},
                 )
 
-                self.simple_insert_many_txn(
+                self.db.simple_insert_many_txn(
                     txn,
                     table="state_groups_state",
                     values=[
@@ -947,7 +951,7 @@ class StateGroupWorkerStore(
                     ],
                 )
             else:
-                self.simple_insert_many_txn(
+                self.db.simple_insert_many_txn(
                     txn,
                     table="state_groups_state",
                     values=[
@@ -993,7 +997,7 @@ class StateGroupWorkerStore(
 
             return state_group
 
-        return self.runInteraction("store_state_group", _store_state_group_txn)
+        return self.db.runInteraction("store_state_group", _store_state_group_txn)
 
     @defer.inlineCallbacks
     def get_referenced_state_groups(self, state_groups):
@@ -1007,7 +1011,7 @@ class StateGroupWorkerStore(
             referenced.
         """
 
-        rows = yield self.simple_select_many_batch(
+        rows = yield self.db.simple_select_many_batch(
             table="event_to_state_groups",
             column="state_group",
             iterable=state_groups,
@@ -1019,32 +1023,30 @@ class StateGroupWorkerStore(
         return set(row["state_group"] for row in rows)
 
 
-class StateBackgroundUpdateStore(
-    StateGroupBackgroundUpdateStore, BackgroundUpdateStore
-):
+class StateBackgroundUpdateStore(StateGroupBackgroundUpdateStore):
 
     STATE_GROUP_DEDUPLICATION_UPDATE_NAME = "state_group_state_deduplication"
     STATE_GROUP_INDEX_UPDATE_NAME = "state_group_state_type_index"
     CURRENT_STATE_INDEX_UPDATE_NAME = "current_state_members_idx"
     EVENT_STATE_GROUP_INDEX_UPDATE_NAME = "event_to_state_groups_sg_index"
 
-    def __init__(self, db_conn, hs):
-        super(StateBackgroundUpdateStore, self).__init__(db_conn, hs)
-        self.register_background_update_handler(
+    def __init__(self, database: Database, db_conn, hs):
+        super(StateBackgroundUpdateStore, self).__init__(database, db_conn, hs)
+        self.db.updates.register_background_update_handler(
             self.STATE_GROUP_DEDUPLICATION_UPDATE_NAME,
             self._background_deduplicate_state,
         )
-        self.register_background_update_handler(
+        self.db.updates.register_background_update_handler(
             self.STATE_GROUP_INDEX_UPDATE_NAME, self._background_index_state
         )
-        self.register_background_index_update(
+        self.db.updates.register_background_index_update(
             self.CURRENT_STATE_INDEX_UPDATE_NAME,
             index_name="current_state_events_member_index",
             table="current_state_events",
             columns=["state_key"],
             where_clause="type='m.room.member'",
         )
-        self.register_background_index_update(
+        self.db.updates.register_background_index_update(
             self.EVENT_STATE_GROUP_INDEX_UPDATE_NAME,
             index_name="event_to_state_groups_sg_index",
             table="event_to_state_groups",
@@ -1065,7 +1067,7 @@ class StateBackgroundUpdateStore(
         batch_size = max(1, int(batch_size / BATCH_SIZE_SCALE_FACTOR))
 
         if max_group is None:
-            rows = yield self.execute(
+            rows = yield self.db.execute(
                 "_background_deduplicate_state",
                 None,
                 "SELECT coalesce(max(id), 0) FROM state_groups",
@@ -1135,13 +1137,13 @@ class StateBackgroundUpdateStore(
                             if prev_state.get(key, None) != value
                         }
 
-                        self.simple_delete_txn(
+                        self.db.simple_delete_txn(
                             txn,
                             table="state_group_edges",
                             keyvalues={"state_group": state_group},
                         )
 
-                        self.simple_insert_txn(
+                        self.db.simple_insert_txn(
                             txn,
                             table="state_group_edges",
                             values={
@@ -1150,13 +1152,13 @@ class StateBackgroundUpdateStore(
                             },
                         )
 
-                        self.simple_delete_txn(
+                        self.db.simple_delete_txn(
                             txn,
                             table="state_groups_state",
                             keyvalues={"state_group": state_group},
                         )
 
-                        self.simple_insert_many_txn(
+                        self.db.simple_insert_many_txn(
                             txn,
                             table="state_groups_state",
                             values=[
@@ -1177,18 +1179,18 @@ class StateBackgroundUpdateStore(
                 "max_group": max_group,
             }
 
-            self._background_update_progress_txn(
+            self.db.updates._background_update_progress_txn(
                 txn, self.STATE_GROUP_DEDUPLICATION_UPDATE_NAME, progress
             )
 
             return False, batch_size
 
-        finished, result = yield self.runInteraction(
+        finished, result = yield self.db.runInteraction(
             self.STATE_GROUP_DEDUPLICATION_UPDATE_NAME, reindex_txn
         )
 
         if finished:
-            yield self._end_background_update(
+            yield self.db.updates._end_background_update(
                 self.STATE_GROUP_DEDUPLICATION_UPDATE_NAME
             )
 
@@ -1218,9 +1220,9 @@ class StateBackgroundUpdateStore(
                 )
                 txn.execute("DROP INDEX IF EXISTS state_groups_state_id")
 
-        yield self.runWithConnection(reindex_txn)
+        yield self.db.runWithConnection(reindex_txn)
 
-        yield self._end_background_update(self.STATE_GROUP_INDEX_UPDATE_NAME)
+        yield self.db.updates._end_background_update(self.STATE_GROUP_INDEX_UPDATE_NAME)
 
         return 1
 
@@ -1244,8 +1246,8 @@ class StateStore(StateGroupWorkerStore, StateBackgroundUpdateStore):
       * `state_groups_state`: Maps state group to state events.
     """
 
-    def __init__(self, db_conn, hs):
-        super(StateStore, self).__init__(db_conn, hs)
+    def __init__(self, database: Database, db_conn, hs):
+        super(StateStore, self).__init__(database, db_conn, hs)
 
     def _store_event_state_mappings_txn(
         self, txn, events_and_contexts: Iterable[Tuple[EventBase, EventContext]]
@@ -1263,7 +1265,7 @@ class StateStore(StateGroupWorkerStore, StateBackgroundUpdateStore):
 
             state_groups[event.event_id] = context.state_group
 
-        self.simple_insert_many_txn(
+        self.db.simple_insert_many_txn(
             txn,
             table="event_to_state_groups",
             values=[
diff --git a/synapse/storage/data_stores/main/state_deltas.py b/synapse/storage/data_stores/main/state_deltas.py
index 03b908026b..12c982cb26 100644
--- a/synapse/storage/data_stores/main/state_deltas.py
+++ b/synapse/storage/data_stores/main/state_deltas.py
@@ -98,14 +98,14 @@ class StateDeltasStore(SQLBaseStore):
                 ORDER BY stream_id ASC
             """
             txn.execute(sql, (prev_stream_id, clipped_stream_id))
-            return clipped_stream_id, self.cursor_to_dict(txn)
+            return clipped_stream_id, self.db.cursor_to_dict(txn)
 
-        return self.runInteraction(
+        return self.db.runInteraction(
             "get_current_state_deltas", get_current_state_deltas_txn
         )
 
     def _get_max_stream_id_in_current_state_deltas_txn(self, txn):
-        return self.simple_select_one_onecol_txn(
+        return self.db.simple_select_one_onecol_txn(
             txn,
             table="current_state_delta_stream",
             keyvalues={},
@@ -113,7 +113,7 @@ class StateDeltasStore(SQLBaseStore):
         )
 
     def get_max_stream_id_in_current_state_deltas(self):
-        return self.runInteraction(
+        return self.db.runInteraction(
             "get_max_stream_id_in_current_state_deltas",
             self._get_max_stream_id_in_current_state_deltas_txn,
         )
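
state_deltas.py follows suit for the transaction plumbing: transaction functions are still handed a txn, but they are run via self.db.runInteraction, and row decoding goes through self.db.cursor_to_dict. A sketch under the same assumptions (a hypothetical widgets table, as a method on the hypothetical store above):

    def get_recent_widgets(self, limit):
        def get_recent_widgets_txn(txn):
            sql = """
                SELECT owner_id, widget_id FROM widgets
                ORDER BY widget_id DESC LIMIT ?
            """
            txn.execute(sql, (limit,))
            # cursor_to_dict, also moved onto self.db, turns the cursor's
            # remaining rows into a list of dicts keyed by column name.
            return self.db.cursor_to_dict(txn)

        return self.db.runInteraction("get_recent_widgets", get_recent_widgets_txn)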
diff --git a/synapse/storage/data_stores/main/stats.py b/synapse/storage/data_stores/main/stats.py
index b306478824..7bc186e9a1 100644
--- a/synapse/storage/data_stores/main/stats.py
+++ b/synapse/storage/data_stores/main/stats.py
@@ -22,6 +22,7 @@ from twisted.internet.defer import DeferredLock
 
 from synapse.api.constants import EventTypes, Membership
 from synapse.storage.data_stores.main.state_deltas import StateDeltasStore
+from synapse.storage.database import Database
 from synapse.storage.engines import PostgresEngine
 from synapse.util.caches.descriptors import cached
 
@@ -58,8 +59,8 @@ TYPE_TO_ORIGIN_TABLE = {"room": ("rooms", "room_id"), "user": ("users", "name")}
 
 
 class StatsStore(StateDeltasStore):
-    def __init__(self, db_conn, hs):
-        super(StatsStore, self).__init__(db_conn, hs)
+    def __init__(self, database: Database, db_conn, hs):
+        super(StatsStore, self).__init__(database, db_conn, hs)
 
         self.server_name = hs.hostname
         self.clock = self.hs.get_clock()
@@ -68,17 +69,17 @@ class StatsStore(StateDeltasStore):
 
         self.stats_delta_processing_lock = DeferredLock()
 
-        self.register_background_update_handler(
+        self.db.updates.register_background_update_handler(
             "populate_stats_process_rooms", self._populate_stats_process_rooms
         )
-        self.register_background_update_handler(
+        self.db.updates.register_background_update_handler(
             "populate_stats_process_users", self._populate_stats_process_users
         )
         # we no longer need to perform clean-up, but we will give ourselves
         # the potential to reintroduce it in the future – so documentation
         # will still encourage the use of this no-op handler.
-        self.register_noop_background_update("populate_stats_cleanup")
-        self.register_noop_background_update("populate_stats_prepare")
+        self.db.updates.register_noop_background_update("populate_stats_cleanup")
+        self.db.updates.register_noop_background_update("populate_stats_prepare")
 
     def quantise_stats_time(self, ts):
         """
@@ -102,7 +103,7 @@ class StatsStore(StateDeltasStore):
         This is a background update which regenerates statistics for users.
         """
         if not self.stats_enabled:
-            yield self._end_background_update("populate_stats_process_users")
+            yield self.db.updates._end_background_update("populate_stats_process_users")
             return 1
 
         last_user_id = progress.get("last_user_id", "")
@@ -117,22 +118,22 @@ class StatsStore(StateDeltasStore):
             txn.execute(sql, (last_user_id, batch_size))
             return [r for r, in txn]
 
-        users_to_work_on = yield self.runInteraction(
+        users_to_work_on = yield self.db.runInteraction(
             "_populate_stats_process_users", _get_next_batch
         )
 
         # No more users -- complete the transaction.
         if not users_to_work_on:
-            yield self._end_background_update("populate_stats_process_users")
+            yield self.db.updates._end_background_update("populate_stats_process_users")
             return 1
 
         for user_id in users_to_work_on:
             yield self._calculate_and_set_initial_state_for_user(user_id)
             progress["last_user_id"] = user_id
 
-        yield self.runInteraction(
+        yield self.db.runInteraction(
             "populate_stats_process_users",
-            self._background_update_progress_txn,
+            self.db.updates._background_update_progress_txn,
             "populate_stats_process_users",
             progress,
         )
@@ -145,7 +146,7 @@ class StatsStore(StateDeltasStore):
         This is a background update which regenerates statistics for rooms.
         """
         if not self.stats_enabled:
-            yield self._end_background_update("populate_stats_process_rooms")
+            yield self.db.updates._end_background_update("populate_stats_process_rooms")
             return 1
 
         last_room_id = progress.get("last_room_id", "")
@@ -160,22 +161,22 @@ class StatsStore(StateDeltasStore):
             txn.execute(sql, (last_room_id, batch_size))
             return [r for r, in txn]
 
-        rooms_to_work_on = yield self.runInteraction(
+        rooms_to_work_on = yield self.db.runInteraction(
             "populate_stats_rooms_get_batch", _get_next_batch
         )
 
         # No more rooms -- complete the transaction.
         if not rooms_to_work_on:
-            yield self._end_background_update("populate_stats_process_rooms")
+            yield self.db.updates._end_background_update("populate_stats_process_rooms")
             return 1
 
         for room_id in rooms_to_work_on:
             yield self._calculate_and_set_initial_state_for_room(room_id)
             progress["last_room_id"] = room_id
 
-        yield self.runInteraction(
+        yield self.db.runInteraction(
             "_populate_stats_process_rooms",
-            self._background_update_progress_txn,
+            self.db.updates._background_update_progress_txn,
             "populate_stats_process_rooms",
             progress,
         )
@@ -186,7 +187,7 @@ class StatsStore(StateDeltasStore):
         """
         Returns the stats processor position.
         """
-        return self.simple_select_one_onecol(
+        return self.db.simple_select_one_onecol(
             table="stats_incremental_position",
             keyvalues={},
             retcol="stream_id",
@@ -215,7 +216,7 @@ class StatsStore(StateDeltasStore):
             if field and "\0" in field:
                 fields[col] = None
 
-        return self.simple_upsert(
+        return self.db.simple_upsert(
             table="room_stats_state",
             keyvalues={"room_id": room_id},
             values=fields,
@@ -236,7 +237,7 @@ class StatsStore(StateDeltasStore):
             Deferred[list[dict]], where the dict has the keys of
             ABSOLUTE_STATS_FIELDS[stats_type], and "bucket_size" and "end_ts".
         """
-        return self.runInteraction(
+        return self.db.runInteraction(
             "get_statistics_for_subject",
             self._get_statistics_for_subject_txn,
             stats_type,
@@ -257,7 +258,7 @@ class StatsStore(StateDeltasStore):
             ABSOLUTE_STATS_FIELDS[stats_type] + PER_SLICE_FIELDS[stats_type]
         )
 
-        slice_list = self.simple_select_list_paginate_txn(
+        slice_list = self.db.simple_select_list_paginate_txn(
             txn,
             table + "_historical",
             "end_ts",
@@ -282,7 +283,7 @@ class StatsStore(StateDeltasStore):
                 "name", "topic", "canonical_alias", "avatar", "join_rules",
                 "history_visibility"
         """
-        return self.simple_select_one(
+        return self.db.simple_select_one(
             "room_stats_state",
             {"room_id": room_id},
             retcols=(
@@ -308,7 +309,7 @@ class StatsStore(StateDeltasStore):
         """
         table, id_col = TYPE_TO_TABLE[stats_type]
 
-        return self.simple_select_one_onecol(
+        return self.db.simple_select_one_onecol(
             "%s_current" % (table,),
             keyvalues={id_col: id},
             retcol="completed_delta_stream_id",
@@ -344,14 +345,14 @@ class StatsStore(StateDeltasStore):
                         complete_with_stream_id=stream_id,
                     )
 
-            self.simple_update_one_txn(
+            self.db.simple_update_one_txn(
                 txn,
                 table="stats_incremental_position",
                 keyvalues={},
                 updatevalues={"stream_id": stream_id},
             )
 
-        return self.runInteraction(
+        return self.db.runInteraction(
             "bulk_update_stats_delta", _bulk_update_stats_delta_txn
         )
 
@@ -382,7 +383,7 @@ class StatsStore(StateDeltasStore):
                 Does not work with per-slice fields.
         """
 
-        return self.runInteraction(
+        return self.db.runInteraction(
             "update_stats_delta",
             self._update_stats_delta_txn,
             ts,
@@ -517,17 +518,17 @@ class StatsStore(StateDeltasStore):
         else:
             self.database_engine.lock_table(txn, table)
             retcols = list(chain(absolutes.keys(), additive_relatives.keys()))
-            current_row = self.simple_select_one_txn(
+            current_row = self.db.simple_select_one_txn(
                 txn, table, keyvalues, retcols, allow_none=True
             )
             if current_row is None:
                 merged_dict = {**keyvalues, **absolutes, **additive_relatives}
-                self.simple_insert_txn(txn, table, merged_dict)
+                self.db.simple_insert_txn(txn, table, merged_dict)
             else:
                 for (key, val) in additive_relatives.items():
                     current_row[key] += val
                 current_row.update(absolutes)
-                self.simple_update_one_txn(txn, table, keyvalues, current_row)
+                self.db.simple_update_one_txn(txn, table, keyvalues, current_row)
 
     def _upsert_copy_from_table_with_additive_relatives_txn(
         self,
@@ -614,11 +615,11 @@ class StatsStore(StateDeltasStore):
             txn.execute(sql, qargs)
         else:
             self.database_engine.lock_table(txn, into_table)
-            src_row = self.simple_select_one_txn(
+            src_row = self.db.simple_select_one_txn(
                 txn, src_table, keyvalues, copy_columns
             )
             all_dest_keyvalues = {**keyvalues, **extra_dst_keyvalues}
-            dest_current_row = self.simple_select_one_txn(
+            dest_current_row = self.db.simple_select_one_txn(
                 txn,
                 into_table,
                 keyvalues=all_dest_keyvalues,
@@ -634,11 +635,11 @@ class StatsStore(StateDeltasStore):
                     **src_row,
                     **additive_relatives,
                 }
-                self.simple_insert_txn(txn, into_table, merged_dict)
+                self.db.simple_insert_txn(txn, into_table, merged_dict)
             else:
                 for (key, val) in additive_relatives.items():
                     src_row[key] = dest_current_row[key] + val
-                self.simple_update_txn(txn, into_table, all_dest_keyvalues, src_row)
+                self.db.simple_update_txn(txn, into_table, all_dest_keyvalues, src_row)
 
     def get_changes_room_total_events_and_bytes(self, min_pos, max_pos):
         """Fetches the counts of events in the given range of stream IDs.
@@ -652,7 +653,7 @@ class StatsStore(StateDeltasStore):
             changes.
         """
 
-        return self.runInteraction(
+        return self.db.runInteraction(
             "stats_incremental_total_events_and_bytes",
             self.get_changes_room_total_events_and_bytes_txn,
             min_pos,
@@ -735,7 +736,7 @@ class StatsStore(StateDeltasStore):
         def _fetch_current_state_stats(txn):
             pos = self.get_room_max_stream_ordering()
 
-            rows = self.simple_select_many_txn(
+            rows = self.db.simple_select_many_txn(
                 txn,
                 table="current_state_events",
                 column="type",
@@ -791,7 +792,7 @@ class StatsStore(StateDeltasStore):
             current_state_events_count,
             users_in_room,
             pos,
-        ) = yield self.runInteraction(
+        ) = yield self.db.runInteraction(
             "get_initial_state_for_room", _fetch_current_state_stats
         )
 
@@ -866,7 +867,7 @@ class StatsStore(StateDeltasStore):
             (count,) = txn.fetchone()
             return count, pos
 
-        joined_rooms, pos = yield self.runInteraction(
+        joined_rooms, pos = yield self.db.runInteraction(
             "calculate_and_set_initial_state_for_user",
             _calculate_and_set_initial_state_for_user_txn,
         )
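
stats.py shows the other half of the migration: the background-update machinery now hangs off the database as self.db.updates (a BackgroundUpdater), so handler registration, progress tracking and completion all go through that object. A hedged sketch of a batched handler in the new style; the populate_example update name and the widgets query are assumptions, not part of the patch:

    from twisted.internet import defer

    from synapse.storage._base import SQLBaseStore


    class ExampleBackgroundStore(SQLBaseStore):
        def __init__(self, database, db_conn, hs):
            super(ExampleBackgroundStore, self).__init__(database, db_conn, hs)
            self.db.updates.register_background_update_handler(
                "populate_example", self._populate_example
            )

        @defer.inlineCallbacks
        def _populate_example(self, progress, batch_size):
            last_id = progress.get("last_id", "")

            def _get_next_batch(txn):
                txn.execute(
                    "SELECT widget_id FROM widgets WHERE widget_id > ? "
                    "ORDER BY widget_id LIMIT ?",
                    (last_id, batch_size),
                )
                return [r for r, in txn]

            batch = yield self.db.runInteraction("populate_example", _get_next_batch)

            # No more rows: mark the background update as complete.
            if not batch:
                yield self.db.updates._end_background_update("populate_example")
                return 1

            progress["last_id"] = batch[-1]
            yield self.db.runInteraction(
                "populate_example",
                self.db.updates._background_update_progress_txn,
                "populate_example",
                progress,
            )
            return len(batch)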
diff --git a/synapse/storage/data_stores/main/stream.py b/synapse/storage/data_stores/main/stream.py
index 60487c4559..140da8dad6 100644
--- a/synapse/storage/data_stores/main/stream.py
+++ b/synapse/storage/data_stores/main/stream.py
@@ -47,6 +47,7 @@ from twisted.internet import defer
 from synapse.logging.context import make_deferred_yieldable, run_in_background
 from synapse.storage._base import SQLBaseStore
 from synapse.storage.data_stores.main.events_worker import EventsWorkerStore
+from synapse.storage.database import Database
 from synapse.storage.engines import PostgresEngine
 from synapse.types import RoomStreamToken
 from synapse.util.caches.stream_change_cache import StreamChangeCache
@@ -251,11 +252,11 @@ class StreamWorkerStore(EventsWorkerStore, SQLBaseStore):
 
     __metaclass__ = abc.ABCMeta
 
-    def __init__(self, db_conn, hs):
-        super(StreamWorkerStore, self).__init__(db_conn, hs)
+    def __init__(self, database: Database, db_conn, hs):
+        super(StreamWorkerStore, self).__init__(database, db_conn, hs)
 
         events_max = self.get_room_max_stream_ordering()
-        event_cache_prefill, min_event_val = self.get_cache_dict(
+        event_cache_prefill, min_event_val = self.db.get_cache_dict(
             db_conn,
             "events",
             entity_column="room_id",
@@ -400,7 +401,7 @@ class StreamWorkerStore(EventsWorkerStore, SQLBaseStore):
             rows = [_EventDictReturn(row[0], None, row[1]) for row in txn]
             return rows
 
-        rows = yield self.runInteraction("get_room_events_stream_for_room", f)
+        rows = yield self.db.runInteraction("get_room_events_stream_for_room", f)
 
         ret = yield self.get_events_as_list(
             [r.event_id for r in rows], get_prev_content=True
@@ -450,7 +451,7 @@ class StreamWorkerStore(EventsWorkerStore, SQLBaseStore):
 
             return rows
 
-        rows = yield self.runInteraction("get_membership_changes_for_user", f)
+        rows = yield self.db.runInteraction("get_membership_changes_for_user", f)
 
         ret = yield self.get_events_as_list(
             [r.event_id for r in rows], get_prev_content=True
@@ -511,7 +512,7 @@ class StreamWorkerStore(EventsWorkerStore, SQLBaseStore):
 
         end_token = RoomStreamToken.parse(end_token)
 
-        rows, token = yield self.runInteraction(
+        rows, token = yield self.db.runInteraction(
             "get_recent_event_ids_for_room",
             self._paginate_room_events_txn,
             room_id,
@@ -548,7 +549,7 @@ class StreamWorkerStore(EventsWorkerStore, SQLBaseStore):
             txn.execute(sql, (room_id, stream_ordering))
             return txn.fetchone()
 
-        return self.runInteraction("get_room_event_after_stream_ordering", _f)
+        return self.db.runInteraction("get_room_event_after_stream_ordering", _f)
 
     @defer.inlineCallbacks
     def get_room_events_max_id(self, room_id=None):
@@ -562,7 +563,7 @@ class StreamWorkerStore(EventsWorkerStore, SQLBaseStore):
         if room_id is None:
             return "s%d" % (token,)
         else:
-            topo = yield self.runInteraction(
+            topo = yield self.db.runInteraction(
                 "_get_max_topological_txn", self._get_max_topological_txn, room_id
             )
             return "t%d-%d" % (topo, token)
@@ -576,7 +577,7 @@ class StreamWorkerStore(EventsWorkerStore, SQLBaseStore):
         Returns:
             A deferred "s%d" stream token.
         """
-        return self.simple_select_one_onecol(
+        return self.db.simple_select_one_onecol(
             table="events", keyvalues={"event_id": event_id}, retcol="stream_ordering"
         ).addCallback(lambda row: "s%d" % (row,))
 
@@ -589,7 +590,7 @@ class StreamWorkerStore(EventsWorkerStore, SQLBaseStore):
         Returns:
             A deferred "t%d-%d" topological token.
         """
-        return self.simple_select_one(
+        return self.db.simple_select_one(
             table="events",
             keyvalues={"event_id": event_id},
             retcols=("stream_ordering", "topological_ordering"),
@@ -613,7 +614,7 @@ class StreamWorkerStore(EventsWorkerStore, SQLBaseStore):
             "SELECT coalesce(max(topological_ordering), 0) FROM events"
             " WHERE room_id = ? AND stream_ordering < ?"
         )
-        return self.execute(
+        return self.db.execute(
             "get_max_topological_token", None, sql, room_id, stream_key
         ).addCallback(lambda r: r[0][0] if r else 0)
 
@@ -667,7 +668,7 @@ class StreamWorkerStore(EventsWorkerStore, SQLBaseStore):
             dict
         """
 
-        results = yield self.runInteraction(
+        results = yield self.db.runInteraction(
             "get_events_around",
             self._get_events_around_txn,
             room_id,
@@ -709,7 +710,7 @@ class StreamWorkerStore(EventsWorkerStore, SQLBaseStore):
             dict
         """
 
-        results = self.simple_select_one_txn(
+        results = self.db.simple_select_one_txn(
             txn,
             "events",
             keyvalues={"event_id": event_id, "room_id": room_id},
@@ -788,7 +789,7 @@ class StreamWorkerStore(EventsWorkerStore, SQLBaseStore):
 
             return upper_bound, [row[1] for row in rows]
 
-        upper_bound, event_ids = yield self.runInteraction(
+        upper_bound, event_ids = yield self.db.runInteraction(
             "get_all_new_events_stream", get_all_new_events_stream_txn
         )
 
@@ -797,7 +798,7 @@ class StreamWorkerStore(EventsWorkerStore, SQLBaseStore):
         return upper_bound, events
 
     def get_federation_out_pos(self, typ):
-        return self.simple_select_one_onecol(
+        return self.db.simple_select_one_onecol(
             table="federation_stream_position",
             retcol="stream_id",
             keyvalues={"type": typ},
@@ -805,7 +806,7 @@ class StreamWorkerStore(EventsWorkerStore, SQLBaseStore):
         )
 
     def update_federation_out_pos(self, typ, stream_id):
-        return self.simple_update_one(
+        return self.db.simple_update_one(
             table="federation_stream_position",
             keyvalues={"type": typ},
             updatevalues={"stream_id": stream_id},
@@ -956,7 +957,7 @@ class StreamWorkerStore(EventsWorkerStore, SQLBaseStore):
         if to_key:
             to_key = RoomStreamToken.parse(to_key)
 
-        rows, token = yield self.runInteraction(
+        rows, token = yield self.db.runInteraction(
             "paginate_room_events",
             self._paginate_room_events_txn,
             room_id,
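
stream.py also demonstrates the constructor change applied to every store in this patch: __init__ gains a leading database: Database argument which is threaded straight through to super(), so all the mixins end up sharing one Database instance, and startup cache prefills now go through self.db.get_cache_dict. A sketch of the convention (the widgets table and the helper it calls are invented):

    from synapse.storage._base import SQLBaseStore
    from synapse.storage.database import Database


    class ExampleWorkerStore(SQLBaseStore):
        def __init__(self, database: Database, db_conn, hs):
            # The Database instance comes first and is passed up the MRO.
            super(ExampleWorkerStore, self).__init__(database, db_conn, hs)

            # get_cache_dict (now on self.db) reads straight from the startup
            # connection to prefill an in-memory cache of recent stream ids.
            prefill, min_val = self.db.get_cache_dict(
                db_conn,
                "widgets",
                entity_column="owner_id",
                stream_column="widget_id",
                max_value=self.get_max_widget_id(),  # hypothetical helper
            )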
diff --git a/synapse/storage/data_stores/main/tags.py b/synapse/storage/data_stores/main/tags.py
index 85012403be..2aa1bafd48 100644
--- a/synapse/storage/data_stores/main/tags.py
+++ b/synapse/storage/data_stores/main/tags.py
@@ -41,7 +41,7 @@ class TagsWorkerStore(AccountDataWorkerStore):
             tag strings to tag content.
         """
 
-        deferred = self.simple_select_list(
+        deferred = self.db.simple_select_list(
             "room_tags", {"user_id": user_id}, ["room_id", "tag", "content"]
         )
 
@@ -78,7 +78,7 @@ class TagsWorkerStore(AccountDataWorkerStore):
             txn.execute(sql, (last_id, current_id, limit))
             return txn.fetchall()
 
-        tag_ids = yield self.runInteraction(
+        tag_ids = yield self.db.runInteraction(
             "get_all_updated_tags", get_all_updated_tags_txn
         )
 
@@ -98,7 +98,7 @@ class TagsWorkerStore(AccountDataWorkerStore):
         batch_size = 50
         results = []
         for i in range(0, len(tag_ids), batch_size):
-            tags = yield self.runInteraction(
+            tags = yield self.db.runInteraction(
                 "get_all_updated_tag_content",
                 get_tag_content,
                 tag_ids[i : i + batch_size],
@@ -135,7 +135,9 @@ class TagsWorkerStore(AccountDataWorkerStore):
         if not changed:
             return {}
 
-        room_ids = yield self.runInteraction("get_updated_tags", get_updated_tags_txn)
+        room_ids = yield self.db.runInteraction(
+            "get_updated_tags", get_updated_tags_txn
+        )
 
         results = {}
         if room_ids:
@@ -153,7 +155,7 @@ class TagsWorkerStore(AccountDataWorkerStore):
         Returns:
             A deferred list of string tags.
         """
-        return self.simple_select_list(
+        return self.db.simple_select_list(
             table="room_tags",
             keyvalues={"user_id": user_id, "room_id": room_id},
             retcols=("tag", "content"),
@@ -178,7 +180,7 @@ class TagsStore(TagsWorkerStore):
         content_json = json.dumps(content)
 
         def add_tag_txn(txn, next_id):
-            self.simple_upsert_txn(
+            self.db.simple_upsert_txn(
                 txn,
                 table="room_tags",
                 keyvalues={"user_id": user_id, "room_id": room_id, "tag": tag},
@@ -187,7 +189,7 @@ class TagsStore(TagsWorkerStore):
             self._update_revision_txn(txn, user_id, room_id, next_id)
 
         with self._account_data_id_gen.get_next() as next_id:
-            yield self.runInteraction("add_tag", add_tag_txn, next_id)
+            yield self.db.runInteraction("add_tag", add_tag_txn, next_id)
 
         self.get_tags_for_user.invalidate((user_id,))
 
@@ -210,7 +212,7 @@ class TagsStore(TagsWorkerStore):
             self._update_revision_txn(txn, user_id, room_id, next_id)
 
         with self._account_data_id_gen.get_next() as next_id:
-            yield self.runInteraction("remove_tag", remove_tag_txn, next_id)
+            yield self.db.runInteraction("remove_tag", remove_tag_txn, next_id)
 
         self.get_tags_for_user.invalidate((user_id,))
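
tags.py captures the typical write path: build a transaction function, run it via self.db.runInteraction under a freshly allocated stream id, then invalidate the affected cache once the interaction has committed. The same shape in a hedged sketch, where the id generator and the cached accessor are hypothetical:

    from twisted.internet import defer


    @defer.inlineCallbacks
    def set_widget_colour(self, owner_id, widget_id, colour):
        def set_widget_colour_txn(txn, next_id):
            self.db.simple_upsert_txn(
                txn,
                table="widget_colours",
                keyvalues={"owner_id": owner_id, "widget_id": widget_id},
                values={"colour": colour, "stream_id": next_id},
            )

        with self._widget_id_gen.get_next() as next_id:
            yield self.db.runInteraction(
                "set_widget_colour", set_widget_colour_txn, next_id
            )

        # Invalidate only after the transaction has committed, as add_tag does.
        self.get_widgets_for_owner.invalidate((owner_id,))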
 
diff --git a/synapse/storage/data_stores/main/transactions.py b/synapse/storage/data_stores/main/transactions.py
index c162f3ea16..5b07c2fbc0 100644
--- a/synapse/storage/data_stores/main/transactions.py
+++ b/synapse/storage/data_stores/main/transactions.py
@@ -24,6 +24,7 @@ from twisted.internet import defer
 
 from synapse.metrics.background_process_metrics import run_as_background_process
 from synapse.storage._base import SQLBaseStore, db_to_json
+from synapse.storage.database import Database
 from synapse.util.caches.expiringcache import ExpiringCache
 
 # py2 sqlite has buffer hardcoded as only binary type, so we must use it,
@@ -52,8 +53,8 @@ class TransactionStore(SQLBaseStore):
     """A collection of queries for handling PDUs.
     """
 
-    def __init__(self, db_conn, hs):
-        super(TransactionStore, self).__init__(db_conn, hs)
+    def __init__(self, database: Database, db_conn, hs):
+        super(TransactionStore, self).__init__(database, db_conn, hs)
 
         self._clock.looping_call(self._start_cleanup_transactions, 30 * 60 * 1000)
 
@@ -77,7 +78,7 @@ class TransactionStore(SQLBaseStore):
             this transaction or a 2-tuple of (int, dict)
         """
 
-        return self.runInteraction(
+        return self.db.runInteraction(
             "get_received_txn_response",
             self._get_received_txn_response,
             transaction_id,
@@ -85,7 +86,7 @@ class TransactionStore(SQLBaseStore):
         )
 
     def _get_received_txn_response(self, txn, transaction_id, origin):
-        result = self.simple_select_one_txn(
+        result = self.db.simple_select_one_txn(
             txn,
             table="received_transactions",
             keyvalues={"transaction_id": transaction_id, "origin": origin},
@@ -119,7 +120,7 @@ class TransactionStore(SQLBaseStore):
             response_json (str)
         """
 
-        return self.simple_insert(
+        return self.db.simple_insert(
             table="received_transactions",
             values={
                 "transaction_id": transaction_id,
@@ -148,7 +149,7 @@ class TransactionStore(SQLBaseStore):
         if result is not SENTINEL:
             return result
 
-        result = yield self.runInteraction(
+        result = yield self.db.runInteraction(
             "get_destination_retry_timings",
             self._get_destination_retry_timings,
             destination,
@@ -160,7 +161,7 @@ class TransactionStore(SQLBaseStore):
         return result
 
     def _get_destination_retry_timings(self, txn, destination):
-        result = self.simple_select_one_txn(
+        result = self.db.simple_select_one_txn(
             txn,
             table="destinations",
             keyvalues={"destination": destination},
@@ -187,7 +188,7 @@ class TransactionStore(SQLBaseStore):
         """
 
         self._destination_retry_cache.pop(destination, None)
-        return self.runInteraction(
+        return self.db.runInteraction(
             "set_destination_retry_timings",
             self._set_destination_retry_timings,
             destination,
@@ -227,7 +228,7 @@ class TransactionStore(SQLBaseStore):
         # We need to be careful here as the data may have changed from under us
         # due to a worker setting the timings.
 
-        prev_row = self.simple_select_one_txn(
+        prev_row = self.db.simple_select_one_txn(
             txn,
             table="destinations",
             keyvalues={"destination": destination},
@@ -236,7 +237,7 @@ class TransactionStore(SQLBaseStore):
         )
 
         if not prev_row:
-            self.simple_insert_txn(
+            self.db.simple_insert_txn(
                 txn,
                 table="destinations",
                 values={
@@ -247,7 +248,7 @@ class TransactionStore(SQLBaseStore):
                 },
             )
         elif retry_interval == 0 or prev_row["retry_interval"] < retry_interval:
-            self.simple_update_one_txn(
+            self.db.simple_update_one_txn(
                 txn,
                 "destinations",
                 keyvalues={"destination": destination},
@@ -270,4 +271,6 @@ class TransactionStore(SQLBaseStore):
         def _cleanup_transactions_txn(txn):
             txn.execute("DELETE FROM received_transactions WHERE ts < ?", (month_ago,))
 
-        return self.runInteraction("_cleanup_transactions", _cleanup_transactions_txn)
+        return self.db.runInteraction(
+            "_cleanup_transactions", _cleanup_transactions_txn
+        )
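
_set_destination_retry_timings above is the canonical guarded read-modify-write: self.db.simple_select_one_txn(..., allow_none=True) decides between simple_insert_txn and simple_update_one_txn within a single interaction. The same shape in a self-contained sketch (widget_counters is an invented table, and this would be a method on the hypothetical store from earlier):

    def _bump_widget_counter_txn(self, txn, widget_id, delta):
        # Read the current row, tolerating its absence.
        row = self.db.simple_select_one_txn(
            txn,
            table="widget_counters",
            keyvalues={"widget_id": widget_id},
            retcols=("count",),
            allow_none=True,
        )
        if row is None:
            self.db.simple_insert_txn(
                txn,
                table="widget_counters",
                values={"widget_id": widget_id, "count": delta},
            )
        else:
            self.db.simple_update_one_txn(
                txn,
                table="widget_counters",
                keyvalues={"widget_id": widget_id},
                updatevalues={"count": row["count"] + delta},
            )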
diff --git a/synapse/storage/data_stores/main/user_directory.py b/synapse/storage/data_stores/main/user_directory.py
index 1a85aabbfb..90c180ec6d 100644
--- a/synapse/storage/data_stores/main/user_directory.py
+++ b/synapse/storage/data_stores/main/user_directory.py
@@ -19,9 +19,9 @@ import re
 from twisted.internet import defer
 
 from synapse.api.constants import EventTypes, JoinRules
-from synapse.storage.background_updates import BackgroundUpdateStore
 from synapse.storage.data_stores.main.state import StateFilter
 from synapse.storage.data_stores.main.state_deltas import StateDeltasStore
+from synapse.storage.database import Database
 from synapse.storage.engines import PostgresEngine, Sqlite3Engine
 from synapse.types import get_domain_from_id, get_localpart_from_id
 from synapse.util.caches.descriptors import cached
@@ -32,30 +32,30 @@ logger = logging.getLogger(__name__)
 TEMP_TABLE = "_temp_populate_user_directory"
 
 
-class UserDirectoryBackgroundUpdateStore(StateDeltasStore, BackgroundUpdateStore):
+class UserDirectoryBackgroundUpdateStore(StateDeltasStore):
 
     # How many records do we calculate before sending them to
     # add_users_who_share_private_rooms?
     SHARE_PRIVATE_WORKING_SET = 500
 
-    def __init__(self, db_conn, hs):
-        super(UserDirectoryBackgroundUpdateStore, self).__init__(db_conn, hs)
+    def __init__(self, database: Database, db_conn, hs):
+        super(UserDirectoryBackgroundUpdateStore, self).__init__(database, db_conn, hs)
 
         self.server_name = hs.hostname
 
-        self.register_background_update_handler(
+        self.db.updates.register_background_update_handler(
             "populate_user_directory_createtables",
             self._populate_user_directory_createtables,
         )
-        self.register_background_update_handler(
+        self.db.updates.register_background_update_handler(
             "populate_user_directory_process_rooms",
             self._populate_user_directory_process_rooms,
         )
-        self.register_background_update_handler(
+        self.db.updates.register_background_update_handler(
             "populate_user_directory_process_users",
             self._populate_user_directory_process_users,
         )
-        self.register_background_update_handler(
+        self.db.updates.register_background_update_handler(
             "populate_user_directory_cleanup", self._populate_user_directory_cleanup
         )
 
@@ -85,7 +85,7 @@ class UserDirectoryBackgroundUpdateStore(StateDeltasStore, BackgroundUpdateStore
             """
             txn.execute(sql)
             rooms = [{"room_id": x[0], "events": x[1]} for x in txn.fetchall()]
-            self.simple_insert_many_txn(txn, TEMP_TABLE + "_rooms", rooms)
+            self.db.simple_insert_many_txn(txn, TEMP_TABLE + "_rooms", rooms)
             del rooms
 
             # If search all users is on, get all the users we want to add.
@@ -100,15 +100,17 @@ class UserDirectoryBackgroundUpdateStore(StateDeltasStore, BackgroundUpdateStore
                 txn.execute("SELECT name FROM users")
                 users = [{"user_id": x[0]} for x in txn.fetchall()]
 
-                self.simple_insert_many_txn(txn, TEMP_TABLE + "_users", users)
+                self.db.simple_insert_many_txn(txn, TEMP_TABLE + "_users", users)
 
         new_pos = yield self.get_max_stream_id_in_current_state_deltas()
-        yield self.runInteraction(
+        yield self.db.runInteraction(
             "populate_user_directory_temp_build", _make_staging_area
         )
-        yield self.simple_insert(TEMP_TABLE + "_position", {"position": new_pos})
+        yield self.db.simple_insert(TEMP_TABLE + "_position", {"position": new_pos})
 
-        yield self._end_background_update("populate_user_directory_createtables")
+        yield self.db.updates._end_background_update(
+            "populate_user_directory_createtables"
+        )
         return 1
 
     @defer.inlineCallbacks
@@ -116,7 +118,7 @@ class UserDirectoryBackgroundUpdateStore(StateDeltasStore, BackgroundUpdateStore
         """
         Update the user directory stream position, then clean up the old tables.
         """
-        position = yield self.simple_select_one_onecol(
+        position = yield self.db.simple_select_one_onecol(
             TEMP_TABLE + "_position", None, "position"
         )
         yield self.update_user_directory_stream_pos(position)
@@ -126,11 +128,11 @@ class UserDirectoryBackgroundUpdateStore(StateDeltasStore, BackgroundUpdateStore
             txn.execute("DROP TABLE IF EXISTS " + TEMP_TABLE + "_users")
             txn.execute("DROP TABLE IF EXISTS " + TEMP_TABLE + "_position")
 
-        yield self.runInteraction(
+        yield self.db.runInteraction(
             "populate_user_directory_cleanup", _delete_staging_area
         )
 
-        yield self._end_background_update("populate_user_directory_cleanup")
+        yield self.db.updates._end_background_update("populate_user_directory_cleanup")
         return 1
 
     @defer.inlineCallbacks
@@ -170,13 +172,15 @@ class UserDirectoryBackgroundUpdateStore(StateDeltasStore, BackgroundUpdateStore
 
             return rooms_to_work_on
 
-        rooms_to_work_on = yield self.runInteraction(
+        rooms_to_work_on = yield self.db.runInteraction(
             "populate_user_directory_temp_read", _get_next_batch
         )
 
         # No more rooms -- complete the transaction.
         if not rooms_to_work_on:
-            yield self._end_background_update("populate_user_directory_process_rooms")
+            yield self.db.updates._end_background_update(
+                "populate_user_directory_process_rooms"
+            )
             return 1
 
         logger.info(
@@ -243,12 +247,12 @@ class UserDirectoryBackgroundUpdateStore(StateDeltasStore, BackgroundUpdateStore
                         to_insert.clear()
 
             # We've finished a room. Delete it from the table.
-            yield self.simple_delete_one(TEMP_TABLE + "_rooms", {"room_id": room_id})
+            yield self.db.simple_delete_one(TEMP_TABLE + "_rooms", {"room_id": room_id})
             # Update the remaining counter.
             progress["remaining"] -= 1
-            yield self.runInteraction(
+            yield self.db.runInteraction(
                 "populate_user_directory",
-                self._background_update_progress_txn,
+                self.db.updates._background_update_progress_txn,
                 "populate_user_directory_process_rooms",
                 progress,
             )
@@ -267,7 +271,9 @@ class UserDirectoryBackgroundUpdateStore(StateDeltasStore, BackgroundUpdateStore
         If search_all_users is enabled, add all of the users to the user directory.
         """
         if not self.hs.config.user_directory_search_all_users:
-            yield self._end_background_update("populate_user_directory_process_users")
+            yield self.db.updates._end_background_update(
+                "populate_user_directory_process_users"
+            )
             return 1
 
         def _get_next_batch(txn):
@@ -291,13 +297,15 @@ class UserDirectoryBackgroundUpdateStore(StateDeltasStore, BackgroundUpdateStore
 
             return users_to_work_on
 
-        users_to_work_on = yield self.runInteraction(
+        users_to_work_on = yield self.db.runInteraction(
             "populate_user_directory_temp_read", _get_next_batch
         )
 
         # No more users -- complete the transaction.
         if not users_to_work_on:
-            yield self._end_background_update("populate_user_directory_process_users")
+            yield self.db.updates._end_background_update(
+                "populate_user_directory_process_users"
+            )
             return 1
 
         logger.info(
@@ -312,12 +320,12 @@ class UserDirectoryBackgroundUpdateStore(StateDeltasStore, BackgroundUpdateStore
             )
 
             # We've finished processing a user. Delete it from the table.
-            yield self.simple_delete_one(TEMP_TABLE + "_users", {"user_id": user_id})
+            yield self.db.simple_delete_one(TEMP_TABLE + "_users", {"user_id": user_id})
             # Update the remaining counter.
             progress["remaining"] -= 1
-            yield self.runInteraction(
+            yield self.db.runInteraction(
                 "populate_user_directory",
-                self._background_update_progress_txn,
+                self.db.updates._background_update_progress_txn,
                 "populate_user_directory_process_users",
                 progress,
             )
@@ -361,7 +369,7 @@ class UserDirectoryBackgroundUpdateStore(StateDeltasStore, BackgroundUpdateStore
         """
 
         def _update_profile_in_user_dir_txn(txn):
-            new_entry = self.simple_upsert_txn(
+            new_entry = self.db.simple_upsert_txn(
                 txn,
                 table="user_directory",
                 keyvalues={"user_id": user_id},
@@ -435,7 +443,7 @@ class UserDirectoryBackgroundUpdateStore(StateDeltasStore, BackgroundUpdateStore
                         )
             elif isinstance(self.database_engine, Sqlite3Engine):
                 value = "%s %s" % (user_id, display_name) if display_name else user_id
-                self.simple_upsert_txn(
+                self.db.simple_upsert_txn(
                     txn,
                     table="user_directory_search",
                     keyvalues={"user_id": user_id},
@@ -448,7 +456,7 @@ class UserDirectoryBackgroundUpdateStore(StateDeltasStore, BackgroundUpdateStore
 
             txn.call_after(self.get_user_in_directory.invalidate, (user_id,))
 
-        return self.runInteraction(
+        return self.db.runInteraction(
             "update_profile_in_user_dir", _update_profile_in_user_dir_txn
         )
 
@@ -462,7 +470,7 @@ class UserDirectoryBackgroundUpdateStore(StateDeltasStore, BackgroundUpdateStore
         """
 
         def _add_users_who_share_room_txn(txn):
-            self.simple_upsert_many_txn(
+            self.db.simple_upsert_many_txn(
                 txn,
                 table="users_who_share_private_rooms",
                 key_names=["user_id", "other_user_id", "room_id"],
@@ -474,7 +482,7 @@ class UserDirectoryBackgroundUpdateStore(StateDeltasStore, BackgroundUpdateStore
                 value_values=None,
             )
 
-        return self.runInteraction(
+        return self.db.runInteraction(
             "add_users_who_share_room", _add_users_who_share_room_txn
         )
 
@@ -489,7 +497,7 @@ class UserDirectoryBackgroundUpdateStore(StateDeltasStore, BackgroundUpdateStore
 
         def _add_users_in_public_rooms_txn(txn):
 
-            self.simple_upsert_many_txn(
+            self.db.simple_upsert_many_txn(
                 txn,
                 table="users_in_public_rooms",
                 key_names=["user_id", "room_id"],
@@ -498,7 +506,7 @@ class UserDirectoryBackgroundUpdateStore(StateDeltasStore, BackgroundUpdateStore
                 value_values=None,
             )
 
-        return self.runInteraction(
+        return self.db.runInteraction(
             "add_users_in_public_rooms", _add_users_in_public_rooms_txn
         )
 
@@ -513,13 +521,13 @@ class UserDirectoryBackgroundUpdateStore(StateDeltasStore, BackgroundUpdateStore
             txn.execute("DELETE FROM users_who_share_private_rooms")
             txn.call_after(self.get_user_in_directory.invalidate_all)
 
-        return self.runInteraction(
+        return self.db.runInteraction(
             "delete_all_from_user_dir", _delete_all_from_user_dir_txn
         )
 
     @cached()
     def get_user_in_directory(self, user_id):
-        return self.simple_select_one(
+        return self.db.simple_select_one(
             table="user_directory",
             keyvalues={"user_id": user_id},
             retcols=("display_name", "avatar_url"),
@@ -528,7 +536,7 @@ class UserDirectoryBackgroundUpdateStore(StateDeltasStore, BackgroundUpdateStore
         )
 
     def update_user_directory_stream_pos(self, stream_id):
-        return self.simple_update_one(
+        return self.db.simple_update_one(
             table="user_directory_stream_pos",
             keyvalues={},
             updatevalues={"stream_id": stream_id},
@@ -542,47 +550,47 @@ class UserDirectoryStore(UserDirectoryBackgroundUpdateStore):
     # add_users_who_share_private_rooms?
     SHARE_PRIVATE_WORKING_SET = 500
 
-    def __init__(self, db_conn, hs):
-        super(UserDirectoryStore, self).__init__(db_conn, hs)
+    def __init__(self, database: Database, db_conn, hs):
+        super(UserDirectoryStore, self).__init__(database, db_conn, hs)
 
     def remove_from_user_dir(self, user_id):
         def _remove_from_user_dir_txn(txn):
-            self.simple_delete_txn(
+            self.db.simple_delete_txn(
                 txn, table="user_directory", keyvalues={"user_id": user_id}
             )
-            self.simple_delete_txn(
+            self.db.simple_delete_txn(
                 txn, table="user_directory_search", keyvalues={"user_id": user_id}
             )
-            self.simple_delete_txn(
+            self.db.simple_delete_txn(
                 txn, table="users_in_public_rooms", keyvalues={"user_id": user_id}
             )
-            self.simple_delete_txn(
+            self.db.simple_delete_txn(
                 txn,
                 table="users_who_share_private_rooms",
                 keyvalues={"user_id": user_id},
             )
-            self.simple_delete_txn(
+            self.db.simple_delete_txn(
                 txn,
                 table="users_who_share_private_rooms",
                 keyvalues={"other_user_id": user_id},
             )
             txn.call_after(self.get_user_in_directory.invalidate, (user_id,))
 
-        return self.runInteraction("remove_from_user_dir", _remove_from_user_dir_txn)
+        return self.db.runInteraction("remove_from_user_dir", _remove_from_user_dir_txn)
 
     @defer.inlineCallbacks
     def get_users_in_dir_due_to_room(self, room_id):
         """Get all user_ids that are in the room directory because they're
         in the given room_id
         """
-        user_ids_share_pub = yield self.simple_select_onecol(
+        user_ids_share_pub = yield self.db.simple_select_onecol(
             table="users_in_public_rooms",
             keyvalues={"room_id": room_id},
             retcol="user_id",
             desc="get_users_in_dir_due_to_room",
         )
 
-        user_ids_share_priv = yield self.simple_select_onecol(
+        user_ids_share_priv = yield self.db.simple_select_onecol(
             table="users_who_share_private_rooms",
             keyvalues={"room_id": room_id},
             retcol="other_user_id",
@@ -605,23 +613,23 @@ class UserDirectoryStore(UserDirectoryBackgroundUpdateStore):
         """
 
         def _remove_user_who_share_room_txn(txn):
-            self.simple_delete_txn(
+            self.db.simple_delete_txn(
                 txn,
                 table="users_who_share_private_rooms",
                 keyvalues={"user_id": user_id, "room_id": room_id},
             )
-            self.simple_delete_txn(
+            self.db.simple_delete_txn(
                 txn,
                 table="users_who_share_private_rooms",
                 keyvalues={"other_user_id": user_id, "room_id": room_id},
             )
-            self.simple_delete_txn(
+            self.db.simple_delete_txn(
                 txn,
                 table="users_in_public_rooms",
                 keyvalues={"user_id": user_id, "room_id": room_id},
             )
 
-        return self.runInteraction(
+        return self.db.runInteraction(
             "remove_user_who_share_room", _remove_user_who_share_room_txn
         )
 
@@ -636,14 +644,14 @@ class UserDirectoryStore(UserDirectoryBackgroundUpdateStore):
         Returns:
             list: room_id
         """
-        rows = yield self.simple_select_onecol(
+        rows = yield self.db.simple_select_onecol(
             table="users_who_share_private_rooms",
             keyvalues={"user_id": user_id},
             retcol="room_id",
             desc="get_rooms_user_is_in",
         )
 
-        pub_rows = yield self.simple_select_onecol(
+        pub_rows = yield self.db.simple_select_onecol(
             table="users_in_public_rooms",
             keyvalues={"user_id": user_id},
             retcol="room_id",
@@ -674,14 +682,14 @@ class UserDirectoryStore(UserDirectoryBackgroundUpdateStore):
             ) f2 USING (room_id)
         """
 
-        rows = yield self.execute(
+        rows = yield self.db.execute(
             "get_rooms_in_common_for_users", None, sql, user_id, other_user_id
         )
 
         return [room_id for room_id, in rows]
 
     def get_user_directory_stream_pos(self):
-        return self.simple_select_one_onecol(
+        return self.db.simple_select_one_onecol(
             table="user_directory_stream_pos",
             keyvalues={},
             retcol="stream_id",
@@ -786,7 +794,9 @@ class UserDirectoryStore(UserDirectoryBackgroundUpdateStore):
             # This should be unreachable.
             raise Exception("Unrecognized database engine")
 
-        results = yield self.execute("search_user_dir", self.cursor_to_dict, sql, *args)
+        results = yield self.db.execute(
+            "search_user_dir", self.db.cursor_to_dict, sql, *args
+        )
 
         limited = len(results) > limit
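
search_user_dir above also shows the one-shot query path: self.db.execute(desc, decoder, sql, *args) runs raw SQL outside a named transaction function, with self.db.cursor_to_dict available as an optional row decoder. A sketch with an invented query:

    from twisted.internet import defer


    @defer.inlineCallbacks
    def search_widgets(self, term, limit):
        sql = "SELECT widget_id, owner_id FROM widgets WHERE name LIKE ? LIMIT ?"
        # Passing self.db.cursor_to_dict returns each row as a dict keyed by
        # column name; passing None would yield the raw result tuples.
        results = yield self.db.execute(
            "search_widgets", self.db.cursor_to_dict, sql, "%" + term + "%", limit
        )
        return results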
 
diff --git a/synapse/storage/data_stores/main/user_erasure_store.py b/synapse/storage/data_stores/main/user_erasure_store.py
index 37860af070..af8025bc17 100644
--- a/synapse/storage/data_stores/main/user_erasure_store.py
+++ b/synapse/storage/data_stores/main/user_erasure_store.py
@@ -31,7 +31,7 @@ class UserErasureWorkerStore(SQLBaseStore):
         Returns:
             Deferred[bool]: True if the user has requested erasure
         """
-        return self.simple_select_onecol(
+        return self.db.simple_select_onecol(
             table="erased_users",
             keyvalues={"user_id": user_id},
             retcol="1",
@@ -56,7 +56,7 @@ class UserErasureWorkerStore(SQLBaseStore):
         # iterate it multiple times, and (b) avoiding duplicates.
         user_ids = tuple(set(user_ids))
 
-        rows = yield self.simple_select_many_batch(
+        rows = yield self.db.simple_select_many_batch(
             table="erased_users",
             column="user_id",
             iterable=user_ids,
@@ -88,4 +88,4 @@ class UserErasureStore(UserErasureWorkerStore):
 
             self._invalidate_cache_and_stream(txn, self.is_user_erased, (user_id,))
 
-        return self.runInteraction("mark_user_erased", f)
+        return self.db.runInteraction("mark_user_erased", f)
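
All of these call sites land in the new synapse/storage/database.py below: Database wraps the single physical connection pool and engine, owns the BackgroundUpdater as .updates, and carries runInteraction/runWithConnection plus the simple_* helpers. A rough sketch of how a store gets wired to it; ExampleStore is the hypothetical store from the earlier sketches, and the construction order is an assumption based on the hunks above rather than code from this patch:

    from synapse.storage.database import Database


    def build_example_store(hs, db_conn):
        # One Database per physical database; many data stores share it
        # (its __init__ takes only the homeserver, per the file below).
        database = Database(hs)

        # Stores now take the Database first, then the startup connection
        # and the homeserver, matching the __init__ changes in this patch.
        return ExampleStore(database, db_conn, hs)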
diff --git a/synapse/storage/database.py b/synapse/storage/database.py
new file mode 100644
index 0000000000..ec19ae1d9d
--- /dev/null
+++ b/synapse/storage/database.py
@@ -0,0 +1,1490 @@
+# -*- coding: utf-8 -*-
+# Copyright 2014-2016 OpenMarket Ltd
+# Copyright 2017-2018 New Vector Ltd
+# Copyright 2019 The Matrix.org Foundation C.I.C.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+import logging
+import sys
+import time
+from typing import Iterable, Tuple
+
+from six import iteritems, iterkeys, itervalues
+from six.moves import intern, range
+
+from prometheus_client import Histogram
+
+from twisted.internet import defer
+
+from synapse.api.errors import StoreError
+from synapse.logging.context import LoggingContext, make_deferred_yieldable
+from synapse.metrics.background_process_metrics import run_as_background_process
+from synapse.storage.background_updates import BackgroundUpdater
+from synapse.storage.engines import PostgresEngine, Sqlite3Engine
+from synapse.util.stringutils import exception_to_unicode
+
+# import a function which will return a monotonic time, in seconds
+try:
+    # on python 3, use time.monotonic, since time.clock can go backwards
+    from time import monotonic as monotonic_time
+except ImportError:
+    # ... but python 2 doesn't have it
+    from time import clock as monotonic_time
+
+logger = logging.getLogger(__name__)
+
+try:
+    MAX_TXN_ID = sys.maxint - 1
+except AttributeError:
+    # python 3 does not have a maximum int value
+    MAX_TXN_ID = 2 ** 63 - 1
+
+sql_logger = logging.getLogger("synapse.storage.SQL")
+transaction_logger = logging.getLogger("synapse.storage.txn")
+perf_logger = logging.getLogger("synapse.storage.TIME")
+
+sql_scheduling_timer = Histogram("synapse_storage_schedule_time", "sec")
+
+sql_query_timer = Histogram("synapse_storage_query_time", "sec", ["verb"])
+sql_txn_timer = Histogram("synapse_storage_transaction_time", "sec", ["desc"])
+
+
+# Unique indexes which have been added in background updates. Maps from table name
+# to the name of the background update which added the unique index to that table.
+#
+# This is used by the upsert logic to figure out which tables are safe to do a proper
+# UPSERT on: until the relevant background update has completed, we
+# have to emulate an upsert by locking the table.
+#
+UNIQUE_INDEX_BACKGROUND_UPDATES = {
+    "user_ips": "user_ips_device_unique_index",
+    "device_lists_remote_extremeties": "device_lists_remote_extremeties_unique_idx",
+    "device_lists_remote_cache": "device_lists_remote_cache_unique_idx",
+    "event_search": "event_search_event_id_idx",
+}
+
+
+class LoggingTransaction(object):
+    """An object that almost-transparently proxies for the 'txn' object
+    passed to the constructor. Adds logging and metrics to the .execute()
+    method.
+
+    Args:
+        txn: The database transaction object to wrap.
+        name (str): The name of this transaction, for logging.
+        database_engine (Sqlite3Engine|PostgresEngine)
+        after_callbacks(list|None): A list to which callbacks added via
+            `call_after` will be appended; they should be run on successful
+            completion of the transaction. None indicates that no callbacks
+            may be scheduled to run.
+        exception_callbacks(list|None): A list to which callbacks added via
+            `call_on_exception` will be appended; they should be run if the
+            transaction ends with an error. None indicates that no callbacks
+            may be scheduled to run.
+    """
+
+    __slots__ = [
+        "txn",
+        "name",
+        "database_engine",
+        "after_callbacks",
+        "exception_callbacks",
+    ]
+
+    def __init__(
+        self, txn, name, database_engine, after_callbacks=None, exception_callbacks=None
+    ):
+        object.__setattr__(self, "txn", txn)
+        object.__setattr__(self, "name", name)
+        object.__setattr__(self, "database_engine", database_engine)
+        object.__setattr__(self, "after_callbacks", after_callbacks)
+        object.__setattr__(self, "exception_callbacks", exception_callbacks)
+
+    def call_after(self, callback, *args, **kwargs):
+        """Call the given callback on the main twisted thread after the
+        transaction has finished. Used to invalidate the caches on the
+        correct thread.
+        """
+        self.after_callbacks.append((callback, args, kwargs))
+
+    def call_on_exception(self, callback, *args, **kwargs):
+        self.exception_callbacks.append((callback, args, kwargs))
+
+    def __getattr__(self, name):
+        return getattr(self.txn, name)
+
+    def __setattr__(self, name, value):
+        setattr(self.txn, name, value)
+
+    def __iter__(self):
+        return self.txn.__iter__()
+
+    def execute_batch(self, sql, args):
+        if isinstance(self.database_engine, PostgresEngine):
+            from psycopg2.extras import execute_batch
+
+            self._do_execute(lambda *x: execute_batch(self.txn, *x), sql, args)
+        else:
+            for val in args:
+                self.execute(sql, val)
+
+    def execute(self, sql, *args):
+        self._do_execute(self.txn.execute, sql, *args)
+
+    def executemany(self, sql, *args):
+        self._do_execute(self.txn.executemany, sql, *args)
+
+    def _make_sql_one_line(self, sql):
+        "Strip newlines out of SQL so that the loggers in the DB are on one line"
+        return " ".join(l.strip() for l in sql.splitlines() if l.strip())
+
+    def _do_execute(self, func, sql, *args):
+        sql = self._make_sql_one_line(sql)
+
+        # TODO(paul): Maybe use 'info' and 'debug' for values?
+        sql_logger.debug("[SQL] {%s} %s", self.name, sql)
+
+        sql = self.database_engine.convert_param_style(sql)
+        if args:
+            try:
+                sql_logger.debug("[SQL values] {%s} %r", self.name, args[0])
+            except Exception:
+                # Don't let logging failures stop SQL from working
+                pass
+
+        start = time.time()
+
+        try:
+            return func(sql, *args)
+        except Exception as e:
+            logger.debug("[SQL FAIL] {%s} %s", self.name, e)
+            raise
+        finally:
+            secs = time.time() - start
+            sql_logger.debug("[SQL time] {%s} %f sec", self.name, secs)
+            sql_query_timer.labels(sql.split()[0]).observe(secs)
+
+
+class PerformanceCounters(object):
+    def __init__(self):
+        self.current_counters = {}
+        self.previous_counters = {}
+
+    def update(self, key, duration_secs):
+        count, cum_time = self.current_counters.get(key, (0, 0))
+        count += 1
+        cum_time += duration_secs
+        self.current_counters[key] = (count, cum_time)
+
+    def interval(self, interval_duration_secs, limit=3):
+        counters = []
+        for name, (count, cum_time) in iteritems(self.current_counters):
+            prev_count, prev_time = self.previous_counters.get(name, (0, 0))
+            counters.append(
+                (
+                    (cum_time - prev_time) / interval_duration_secs,
+                    count - prev_count,
+                    name,
+                )
+            )
+
+        self.previous_counters = dict(self.current_counters)
+
+        counters.sort(reverse=True)
+
+        top_n_counters = ", ".join(
+            "%s(%d): %.3f%%" % (name, count, 100 * ratio)
+            for ratio, count, name in counters[:limit]
+        )
+
+        return top_n_counters
+
+
+class Database(object):
+    """Wraps a single physical database and connection pool.
+
+    A single database may be used by multiple data stores.
+    """
+
+    _TXN_ID = 0
+
+    def __init__(self, hs):
+        self.hs = hs
+        self._clock = hs.get_clock()
+        self._db_pool = hs.get_db_pool()
+
+        self.updates = BackgroundUpdater(hs, self)
+
+        self._previous_txn_total_time = 0
+        self._current_txn_total_time = 0
+        self._previous_loop_ts = 0
+
+        # TODO(paul): These can eventually be removed once the metrics code
+        #   is running in mainline, and we have some nice monitoring frontends
+        #   to watch it
+        self._txn_perf_counters = PerformanceCounters()
+
+        self.engine = hs.database_engine
+
+        # A set of tables that are not safe to use native upserts in.
+        self._unsafe_to_upsert_tables = set(UNIQUE_INDEX_BACKGROUND_UPDATES.keys())
+
+        # We add the user_directory_search table to the blacklist on SQLite
+        # because the existing search table does not have an index, making it
+        # unsafe to use native upserts.
+        if isinstance(self.engine, Sqlite3Engine):
+            self._unsafe_to_upsert_tables.add("user_directory_search")
+
+        if self.engine.can_native_upsert:
+            # Check ASAP (and then later, every 1s) to see if we have finished
+            # background updates of tables that aren't safe to update.
+            self._clock.call_later(
+                0.0,
+                run_as_background_process,
+                "upsert_safety_check",
+                self._check_safe_to_upsert,
+            )
+
+    @defer.inlineCallbacks
+    def _check_safe_to_upsert(self):
+        """
+        Is it safe to use native UPSERT?
+
+        If there are background updates, we will need to wait, as they may be
+        the addition of indexes that set the UNIQUE constraint that we require.
+
+        If the background updates have not completed, wait 15 sec and check again.
+        """
+        updates = yield self.simple_select_list(
+            "background_updates",
+            keyvalues=None,
+            retcols=["update_name"],
+            desc="check_background_updates",
+        )
+        updates = [x["update_name"] for x in updates]
+
+        for table, update_name in UNIQUE_INDEX_BACKGROUND_UPDATES.items():
+            if update_name not in updates:
+                logger.debug("Now safe to upsert in %s", table)
+                self._unsafe_to_upsert_tables.discard(table)
+
+        # If any updates are still running, reschedule this check.
+        if updates:
+            self._clock.call_later(
+                15.0,
+                run_as_background_process,
+                "upsert_safety_check",
+                self._check_safe_to_upsert,
+            )
+
+    def start_profiling(self):
+        self._previous_loop_ts = monotonic_time()
+
+        def loop():
+            curr = self._current_txn_total_time
+            prev = self._previous_txn_total_time
+            self._previous_txn_total_time = curr
+
+            time_now = monotonic_time()
+            time_then = self._previous_loop_ts
+            self._previous_loop_ts = time_now
+
+            duration = time_now - time_then
+            ratio = (curr - prev) / duration
+
+            top_three_counters = self._txn_perf_counters.interval(duration, limit=3)
+
+            perf_logger.info(
+                "Total database time: %.3f%% {%s}", ratio * 100, top_three_counters
+            )
+
+        self._clock.looping_call(loop, 10000)
+
+    def new_transaction(
+        self, conn, desc, after_callbacks, exception_callbacks, func, *args, **kwargs
+    ):
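+        """Runs `func` in a new database transaction on `conn`, retrying up to
+        five times on transient errors (operational errors and deadlocks)
+        before giving up, and recording timing metrics for the transaction.
+        """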
+        start = monotonic_time()
+        txn_id = self._TXN_ID
+
+        # We don't really need transaction IDs to be unique, so let's stop the
+        # counter from growing unboundedly large.
+        self._TXN_ID = (self._TXN_ID + 1) % (MAX_TXN_ID)
+
+        name = "%s-%x" % (desc, txn_id)
+
+        transaction_logger.debug("[TXN START] {%s}", name)
+
+        try:
+            i = 0
+            N = 5
+            while True:
+                cursor = LoggingTransaction(
+                    conn.cursor(),
+                    name,
+                    self.engine,
+                    after_callbacks,
+                    exception_callbacks,
+                )
+                try:
+                    r = func(cursor, *args, **kwargs)
+                    conn.commit()
+                    return r
+                except self.engine.module.OperationalError as e:
+                    # This can happen if the database disappears
+                    # mid-transaction.
+                    logger.warning(
+                        "[TXN OPERROR] {%s} %s %d/%d",
+                        name,
+                        exception_to_unicode(e),
+                        i,
+                        N,
+                    )
+                    if i < N:
+                        i += 1
+                        try:
+                            conn.rollback()
+                        except self.engine.module.Error as e1:
+                            logger.warning(
+                                "[TXN EROLL] {%s} %s", name, exception_to_unicode(e1)
+                            )
+                        continue
+                    raise
+                except self.engine.module.DatabaseError as e:
+                    if self.engine.is_deadlock(e):
+                        logger.warning("[TXN DEADLOCK] {%s} %d/%d", name, i, N)
+                        if i < N:
+                            i += 1
+                            try:
+                                conn.rollback()
+                            except self.engine.module.Error as e1:
+                                logger.warning(
+                                    "[TXN EROLL] {%s} %s",
+                                    name,
+                                    exception_to_unicode(e1),
+                                )
+                            continue
+                    raise
+                finally:
+                    # we're either about to retry with a new cursor, or we're about to
+                    # release the connection. Once we release the connection, it could
+                    # get used for another query, which might do a conn.rollback().
+                    #
+                    # In the latter case, even though that probably wouldn't affect the
+                    # results of this transaction, python's sqlite will reset all
+                    # statements on the connection [1], which will make our cursor
+                    # invalid [2].
+                    #
+                    # In any case, continuing to read rows after commit()ing seems
+                    # dubious from the PoV of ACID transactional semantics
+                    # (sqlite explicitly says that once you commit, you may see rows
+                    # from subsequent updates.)
+                    #
+                    # In psycopg2, cursors are essentially a client-side fabrication -
+                    # all the data is transferred to the client side when the statement
+                    # finishes executing - so in theory we could go on streaming results
+                    # from the cursor, but attempting to do so would make us
+                    # incompatible with sqlite, so let's make sure we're not doing that
+                    # by closing the cursor.
+                    #
+                    # (*named* cursors in psycopg2 are different and are proper server-
+                    # side things, but (a) we don't use them and (b) they are implicitly
+                    # closed by ending the transaction anyway.)
+                    #
+                    # In short, if we haven't finished with the cursor yet, that's a
+                    # problem waiting to bite us.
+                    #
+                    # TL;DR: we're done with the cursor, so we can close it.
+                    #
+                    # [1]: https://github.com/python/cpython/blob/v3.8.0/Modules/_sqlite/connection.c#L465
+                    # [2]: https://github.com/python/cpython/blob/v3.8.0/Modules/_sqlite/cursor.c#L236
+                    cursor.close()
+        except Exception as e:
+            logger.debug("[TXN FAIL] {%s} %s", name, e)
+            raise
+        finally:
+            end = monotonic_time()
+            duration = end - start
+
+            LoggingContext.current_context().add_database_transaction(duration)
+
+            transaction_logger.debug("[TXN END] {%s} %f sec", name, duration)
+
+            self._current_txn_total_time += duration
+            self._txn_perf_counters.update(desc, duration)
+            sql_txn_timer.labels(desc).observe(duration)
+
+    @defer.inlineCallbacks
+    def runInteraction(self, desc, func, *args, **kwargs):
+        """Starts a transaction on the database and runs a given function
+
+        Arguments:
+            desc (str): description of the transaction, for logging and metrics
+            func (func): callback function, which will be called with a
+                database transaction (twisted.enterprise.adbapi.Transaction) as
+                its first argument, followed by `args` and `kwargs`.
+
+            args (list): positional args to pass to `func`
+            kwargs (dict): named args to pass to `func`
+
+        Returns:
+            Deferred: The result of func
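+
+        A typical call looks something like this (illustrative only; the
+        transaction function and its arguments are hypothetical):
+
+            def get_name_txn(txn, user_id):
+                txn.execute("SELECT name FROM users WHERE id = ?", (user_id,))
+                return txn.fetchone()
+
+            result = yield db.runInteraction("get_name", get_name_txn, user_id)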
+        """
+        after_callbacks = []
+        exception_callbacks = []
+
+        if LoggingContext.current_context() == LoggingContext.sentinel:
+            logger.warning("Starting db txn '%s' from sentinel context", desc)
+
+        try:
+            result = yield self.runWithConnection(
+                self.new_transaction,
+                desc,
+                after_callbacks,
+                exception_callbacks,
+                func,
+                *args,
+                **kwargs
+            )
+
+            for after_callback, after_args, after_kwargs in after_callbacks:
+                after_callback(*after_args, **after_kwargs)
+        except:  # noqa: E722, as we reraise the exception this is fine.
+            for after_callback, after_args, after_kwargs in exception_callbacks:
+                after_callback(*after_args, **after_kwargs)
+            raise
+
+        return result
+
+    @defer.inlineCallbacks
+    def runWithConnection(self, func, *args, **kwargs):
+        """Wraps the .runWithConnection() method on the underlying db_pool.
+
+        Arguments:
+            func (func): callback function, which will be called with a
+                database connection (twisted.enterprise.adbapi.Connection) as
+                its first argument, followed by `args` and `kwargs`.
+            args (list): positional args to pass to `func`
+            kwargs (dict): named args to pass to `func`
+
+        Returns:
+            Deferred: The result of func
+        """
+        parent_context = LoggingContext.current_context()
+        if parent_context == LoggingContext.sentinel:
+            logger.warning(
+                "Starting db connection from sentinel context: metrics will be lost"
+            )
+            parent_context = None
+
+        start_time = monotonic_time()
+
+        def inner_func(conn, *args, **kwargs):
+            with LoggingContext("runWithConnection", parent_context) as context:
+                sched_duration_sec = monotonic_time() - start_time
+                sql_scheduling_timer.observe(sched_duration_sec)
+                context.add_database_scheduled(sched_duration_sec)
+
+                if self.engine.is_connection_closed(conn):
+                    logger.debug("Reconnecting closed database connection")
+                    conn.reconnect()
+
+                return func(conn, *args, **kwargs)
+
+        result = yield make_deferred_yieldable(
+            self._db_pool.runWithConnection(inner_func, *args, **kwargs)
+        )
+
+        return result
+
+    @staticmethod
+    def cursor_to_dict(cursor):
+        """Converts a SQL cursor into an list of dicts.
+
+        Args:
+            cursor : The DBAPI cursor which has executed a query.
+        Returns:
+            A list of dicts where the key is the column header.
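+
+            For example (illustrative values), a query returning columns
+            ("id", "name") with one row (1, "alice") yields
+            [{"id": 1, "name": "alice"}].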
+        """
+        col_headers = list(intern(str(column[0])) for column in cursor.description)
+        results = list(dict(zip(col_headers, row)) for row in cursor)
+        return results
+
+    def execute(self, desc, decoder, query, *args):
+        """Runs a single query for a result set.
+
+        Args:
+            desc - A description of the query, for logging and metrics
+            decoder - The function which can resolve the cursor results to
+                something meaningful.
+            query - The query string to execute
+            *args - Query args.
+        Returns:
+            The result of decoder(results)
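+
+        Example (illustrative query and table name):
+
+            rows = yield db.execute(
+                "count_users", None, "SELECT count(*) FROM users"
+            )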
+        """
+
+        def interaction(txn):
+            txn.execute(query, args)
+            if decoder:
+                return decoder(txn)
+            else:
+                return txn.fetchall()
+
+        return self.runInteraction(desc, interaction)
+
+    # "Simple" SQL API methods that operate on a single table with no JOINs,
+    # no complex WHERE clauses, just a dict of values for columns.
+
+    @defer.inlineCallbacks
+    def simple_insert(self, table, values, or_ignore=False, desc="simple_insert"):
+        """Executes an INSERT query on the named table.
+
+        Args:
+            table : string giving the table name
+            values : dict of new column names and values for them
+            or_ignore : bool stating whether to ignore a conflicting row: if
+                False (the default), an exception is raised when a conflicting
+                row already exists; if True, the function returns False instead
+            desc : string giving a description of the transaction
+
+        Returns:
+            bool: Whether the row was inserted or not. Only useful when
+            `or_ignore` is True
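+
+        Example (illustrative table and column names):
+
+            inserted = yield db.simple_insert(
+                "user_ips",
+                {"user_id": "@alice:test", "ip": "1.2.3.4"},
+                or_ignore=True,
+            )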
+        """
+        try:
+            yield self.runInteraction(desc, self.simple_insert_txn, table, values)
+        except self.engine.module.IntegrityError:
+            # We have to do or_ignore flag at this layer, since we can't reuse
+            # a cursor after we receive an error from the db.
+            if not or_ignore:
+                raise
+            return False
+        return True
+
+    @staticmethod
+    def simple_insert_txn(txn, table, values):
+        keys, vals = zip(*values.items())
+
+        sql = "INSERT INTO %s (%s) VALUES(%s)" % (
+            table,
+            ", ".join(k for k in keys),
+            ", ".join("?" for _ in keys),
+        )
+
+        txn.execute(sql, vals)
+
+    def simple_insert_many(self, table, values, desc):
+        return self.runInteraction(desc, self.simple_insert_many_txn, table, values)
+
+    @staticmethod
+    def simple_insert_many_txn(txn, table, values):
+        if not values:
+            return
+
+        # This is a *slight* abomination to get a list of tuples of key names
+        # and a list of tuples of values.
+        #
+        # i.e. [{"a": 1, "b": 2}, {"c": 3, "d": 4}]
+        #         => [("a", "b",), ("c", "d",)] and [(1, 2,), (3, 4,)]
+        #
+        # The sort is to ensure that we don't rely on dictionary iteration
+        # order.
+        keys, vals = zip(
+            *[zip(*(sorted(i.items(), key=lambda kv: kv[0]))) for i in values if i]
+        )
+
+        for k in keys:
+            if k != keys[0]:
+                raise RuntimeError("All items must have the same keys")
+
+        sql = "INSERT INTO %s (%s) VALUES(%s)" % (
+            table,
+            ", ".join(k for k in keys[0]),
+            ", ".join("?" for _ in keys[0]),
+        )
+
+        txn.executemany(sql, vals)
+
+    @defer.inlineCallbacks
+    def simple_upsert(
+        self,
+        table,
+        keyvalues,
+        values,
+        insertion_values={},
+        desc="simple_upsert",
+        lock=True,
+    ):
+        """
+
+        `lock` should generally be set to True (the default), but can be set
+        to False if either of the following are true:
+
+        * there is a UNIQUE INDEX on the key columns. In this case a conflict
+          will cause an IntegrityError in which case this function will retry
+          the update.
+
+        * we somehow know that we are the only thread which will be updating
+          this table.
+
+        Args:
+            table (str): The table to upsert into
+            keyvalues (dict): The unique key columns and their new values
+            values (dict): The nonunique columns and their new values
+            insertion_values (dict): additional key/values to use only when
+                inserting
+            lock (bool): True to lock the table when doing the upsert.
+        Returns:
+            Deferred(None or bool): Native upserts always return None. Emulated
+            upserts return True if a new entry was created, False if an existing
+            one was updated.
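+
+        Example (illustrative table and column names):
+
+            yield db.simple_upsert(
+                "device_names",
+                keyvalues={"user_id": "@alice:test", "device_id": "DEV1"},
+                values={"display_name": "phone"},
+            )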
+        """
+        attempts = 0
+        while True:
+            try:
+                result = yield self.runInteraction(
+                    desc,
+                    self.simple_upsert_txn,
+                    table,
+                    keyvalues,
+                    values,
+                    insertion_values,
+                    lock=lock,
+                )
+                return result
+            except self.engine.module.IntegrityError as e:
+                attempts += 1
+                if attempts >= 5:
+                    # don't retry forever, because things other than races
+                    # can cause IntegrityErrors
+                    raise
+
+                # presumably we raced with another transaction: let's retry.
+                logger.warning(
+                    "IntegrityError when upserting into %s; retrying: %s", table, e
+                )
+
+    def simple_upsert_txn(
+        self, txn, table, keyvalues, values, insertion_values={}, lock=True
+    ):
+        """
+        Picks the UPSERT method which works best on the platform: either the
+        native one (PostgreSQL 9.5+, SQLite 3.24+), or an emulated fallback.
+
+        Args:
+            txn: The transaction to use.
+            table (str): The table to upsert into
+            keyvalues (dict): The unique key columns and their new values
+            values (dict): The nonunique columns and their new values
+            insertion_values (dict): additional key/values to use only when
+                inserting
+            lock (bool): True to lock the table when doing the upsert.
+        Returns:
+            None or bool: Native upserts always return None. Emulated
+            upserts return True if a new entry was created, False if an existing
+            one was updated.
+        """
+        if self.engine.can_native_upsert and table not in self._unsafe_to_upsert_tables:
+            return self.simple_upsert_txn_native_upsert(
+                txn, table, keyvalues, values, insertion_values=insertion_values
+            )
+        else:
+            return self.simple_upsert_txn_emulated(
+                txn,
+                table,
+                keyvalues,
+                values,
+                insertion_values=insertion_values,
+                lock=lock,
+            )
+
+    def simple_upsert_txn_emulated(
+        self, txn, table, keyvalues, values, insertion_values={}, lock=True
+    ):
+        """
+        Args:
+            table (str): The table to upsert into
+            keyvalues (dict): The unique key columns and their new values
+            values (dict): The nonunique columns and their new values
+            insertion_values (dict): additional key/values to use only when
+                inserting
+            lock (bool): True to lock the table when doing the upsert.
+        Returns:
+            bool: Return True if a new entry was created, False if an existing
+            one was updated.
+        """
+        # We need to lock the table :(, unless we're *really* careful
+        if lock:
+            self.engine.lock_table(txn, table)
+
+        def _getwhere(key):
+            # If the value we're passing in is None (aka NULL), we need to use
+            # IS, not =, as NULL = NULL evaluates to NULL, which is never true.
+            if keyvalues[key] is None:
+                return "%s IS ?" % (key,)
+            else:
+                return "%s = ?" % (key,)
+
+        if not values:
+            # If `values` is empty, then all of the values we care about are in
+            # the unique key, so there is nothing to UPDATE. We can just do a
+            # SELECT instead to see if it exists.
+            sql = "SELECT 1 FROM %s WHERE %s" % (
+                table,
+                " AND ".join(_getwhere(k) for k in keyvalues),
+            )
+            sqlargs = list(keyvalues.values())
+            txn.execute(sql, sqlargs)
+            if txn.fetchall():
+                # We have an existing record.
+                return False
+        else:
+            # First try to update.
+            sql = "UPDATE %s SET %s WHERE %s" % (
+                table,
+                ", ".join("%s = ?" % (k,) for k in values),
+                " AND ".join(_getwhere(k) for k in keyvalues),
+            )
+            sqlargs = list(values.values()) + list(keyvalues.values())
+
+            txn.execute(sql, sqlargs)
+            if txn.rowcount > 0:
+                # successfully updated at least one row.
+                return False
+
+        # We didn't find any existing rows, so insert a new one
+        allvalues = {}
+        allvalues.update(keyvalues)
+        allvalues.update(values)
+        allvalues.update(insertion_values)
+
+        sql = "INSERT INTO %s (%s) VALUES (%s)" % (
+            table,
+            ", ".join(k for k in allvalues),
+            ", ".join("?" for _ in allvalues),
+        )
+        txn.execute(sql, list(allvalues.values()))
+        # successfully inserted
+        return True
+
+    def simple_upsert_txn_native_upsert(
+        self, txn, table, keyvalues, values, insertion_values={}
+    ):
+        """
+        Uses the native UPSERT functionality (PostgreSQL 9.5+, SQLite 3.24+).
+
+        Args:
+            table (str): The table to upsert into
+            keyvalues (dict): The unique key columns and their new values
+            values (dict): The nonunique columns and their new values
+            insertion_values (dict): additional key/values to use only when
+                inserting
+        Returns:
+            None
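+
+        As an illustration (hypothetical table and column names), upserting
+        with keyvalues={"a": 1} and values={"b": 2} into table "tbl" generates
+        SQL of the form:
+
+            INSERT INTO tbl (a, b) VALUES (?, ?)
+                ON CONFLICT (a) DO UPDATE SET b=EXCLUDED.b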
+        """
+        allvalues = {}
+        allvalues.update(keyvalues)
+        allvalues.update(insertion_values)
+
+        if not values:
+            latter = "NOTHING"
+        else:
+            allvalues.update(values)
+            latter = "UPDATE SET " + ", ".join(k + "=EXCLUDED." + k for k in values)
+
+        sql = ("INSERT INTO %s (%s) VALUES (%s) ON CONFLICT (%s) DO %s") % (
+            table,
+            ", ".join(k for k in allvalues),
+            ", ".join("?" for _ in allvalues),
+            ", ".join(k for k in keyvalues),
+            latter,
+        )
+        txn.execute(sql, list(allvalues.values()))
+
+    def simple_upsert_many_txn(
+        self, txn, table, key_names, key_values, value_names, value_values
+    ):
+        """
+        Upsert, many times.
+
+        Args:
+            table (str): The table to upsert into
+            key_names (list[str]): The key column names.
+            key_values (list[list]): A list of each row's key column values.
+            value_names (list[str]): The value column names. If empty, no
+                values will be used, even if value_values is provided.
+            value_values (list[list]): A list of each row's value column values.
+        Returns:
+            None
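+
+        For example (illustrative data), key_names=["user_id", "device_id"]
+        with key_values=[["@a:test", "D1"], ["@b:test", "D2"]], and
+        value_names=["name"] with value_values=[["phone"], ["laptop"]],
+        upserts two rows in one call.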
+        """
+        if self.engine.can_native_upsert and table not in self._unsafe_to_upsert_tables:
+            return self.simple_upsert_many_txn_native_upsert(
+                txn, table, key_names, key_values, value_names, value_values
+            )
+        else:
+            return self.simple_upsert_many_txn_emulated(
+                txn, table, key_names, key_values, value_names, value_values
+            )
+
+    def simple_upsert_many_txn_emulated(
+        self, txn, table, key_names, key_values, value_names, value_values
+    ):
+        """
+        Upsert, many times, but without native UPSERT support or batching.
+
+        Args:
+            table (str): The table to upsert into
+            key_names (list[str]): The key column names.
+            key_values (list[list]): A list of each row's key column values.
+            value_names (list[str]): The value column names. If empty, no
+                values will be used, even if value_values is provided.
+            value_values (list[list]): A list of each row's value column values.
+        Returns:
+            None
+        """
+        # No value columns, therefore make a blank list so that the following
+        # zip() works correctly.
+        if not value_names:
+            value_values = [() for x in range(len(key_values))]
+
+        for keyv, valv in zip(key_values, value_values):
+            _keys = {x: y for x, y in zip(key_names, keyv)}
+            _vals = {x: y for x, y in zip(value_names, valv)}
+
+            self.simple_upsert_txn_emulated(txn, table, _keys, _vals)
+
+    def simple_upsert_many_txn_native_upsert(
+        self, txn, table, key_names, key_values, value_names, value_values
+    ):
+        """
+        Upsert, many times, using batching where possible.
+
+        Args:
+            table (str): The table to upsert into
+            key_names (list[str]): The key column names.
+            key_values (list[list]): A list of each row's key column values.
+            value_names (list[str]): The value column names. If empty, no
+                values will be used, even if value_values is provided.
+            value_values (list[list]): A list of each row's value column values.
+        Returns:
+            None
+        """
+        allnames = []
+        allnames.extend(key_names)
+        allnames.extend(value_names)
+
+        if not value_names:
+            # No value columns: there is nothing to update on conflict, and we
+            # make a blank list of values so that the zip() below works.
+            latter = "NOTHING"
+            value_values = [() for x in range(len(key_values))]
+        else:
+            latter = "UPDATE SET " + ", ".join(
+                k + "=EXCLUDED." + k for k in value_names
+            )
+
+        sql = "INSERT INTO %s (%s) VALUES (%s) ON CONFLICT (%s) DO %s" % (
+            table,
+            ", ".join(k for k in allnames),
+            ", ".join("?" for _ in allnames),
+            ", ".join(key_names),
+            latter,
+        )
+
+        args = []
+
+        for x, y in zip(key_values, value_values):
+            args.append(tuple(x) + tuple(y))
+
+        return txn.execute_batch(sql, args)
+
+    def simple_select_one(
+        self, table, keyvalues, retcols, allow_none=False, desc="simple_select_one"
+    ):
+        """Executes a SELECT query on the named table, which is expected to
+        return a single row, returning multiple columns from it.
+
+        Args:
+            table : string giving the table name
+            keyvalues : dict of column names and values to select the row with
+            retcols : list of strings giving the names of the columns to return
+
+            allow_none : If true, return None instead of failing if the SELECT
+              statement returns no rows
+        """
+        return self.runInteraction(
+            desc, self.simple_select_one_txn, table, keyvalues, retcols, allow_none
+        )
+
+    def simple_select_one_onecol(
+        self,
+        table,
+        keyvalues,
+        retcol,
+        allow_none=False,
+        desc="simple_select_one_onecol",
+    ):
+        """Executes a SELECT query on the named table, which is expected to
+        return a single row, returning a single column from it.
+
+        Args:
+            table : string giving the table name
+            keyvalues : dict of column names and values to select the row with
+            retcol : string giving the name of the column to return
+        """
+        return self.runInteraction(
+            desc,
+            self.simple_select_one_onecol_txn,
+            table,
+            keyvalues,
+            retcol,
+            allow_none=allow_none,
+        )
+
+    @classmethod
+    def simple_select_one_onecol_txn(
+        cls, txn, table, keyvalues, retcol, allow_none=False
+    ):
+        ret = cls.simple_select_onecol_txn(
+            txn, table=table, keyvalues=keyvalues, retcol=retcol
+        )
+
+        if ret:
+            return ret[0]
+        else:
+            if allow_none:
+                return None
+            else:
+                raise StoreError(404, "No row found")
+
+    @staticmethod
+    def simple_select_onecol_txn(txn, table, keyvalues, retcol):
+        sql = ("SELECT %(retcol)s FROM %(table)s") % {"retcol": retcol, "table": table}
+
+        if keyvalues:
+            sql += " WHERE %s" % " AND ".join("%s = ?" % k for k in iterkeys(keyvalues))
+            txn.execute(sql, list(keyvalues.values()))
+        else:
+            txn.execute(sql)
+
+        return [r[0] for r in txn]
+
+    def simple_select_onecol(
+        self, table, keyvalues, retcol, desc="simple_select_onecol"
+    ):
+        """Executes a SELECT query on the named table, which returns a list
+        comprising of the values of the named column from the selected rows.
+
+        Args:
+            table (str): table name
+            keyvalues (dict|None): column names and values to select the rows with
+            retcol (str): column whose value we wish to retrieve.
+
+        Returns:
+            Deferred: Results in a list
+        """
+        return self.runInteraction(
+            desc, self.simple_select_onecol_txn, table, keyvalues, retcol
+        )
+
+    def simple_select_list(self, table, keyvalues, retcols, desc="simple_select_list"):
+        """Executes a SELECT query on the named table, which may return zero or
+        more rows, returning the result as a list of dicts.
+
+        Args:
+            table (str): the table name
+            keyvalues (dict[str, Any] | None):
+                column names and values to select the rows with, or None to not
+                apply a WHERE clause.
+            retcols (iterable[str]): the names of the columns to return
+        Returns:
+            defer.Deferred: resolves to list[dict[str, Any]]
+        """
+        return self.runInteraction(
+            desc, self.simple_select_list_txn, table, keyvalues, retcols
+        )
+
+    @classmethod
+    def simple_select_list_txn(cls, txn, table, keyvalues, retcols):
+        """Executes a SELECT query on the named table, which may return zero or
+        more rows, returning the result as a list of dicts.
+
+        Args:
+            txn : Transaction object
+            table (str): the table name
+            keyvalues (dict[str, T] | None):
+                column names and values to select the rows with, or None to not
+                apply a WHERE clause.
+            retcols (iterable[str]): the names of the columns to return
+        """
+        if keyvalues:
+            sql = "SELECT %s FROM %s WHERE %s" % (
+                ", ".join(retcols),
+                table,
+                " AND ".join("%s = ?" % (k,) for k in keyvalues),
+            )
+            txn.execute(sql, list(keyvalues.values()))
+        else:
+            sql = "SELECT %s FROM %s" % (", ".join(retcols), table)
+            txn.execute(sql)
+
+        return cls.cursor_to_dict(txn)
+
+    @defer.inlineCallbacks
+    def simple_select_many_batch(
+        self,
+        table,
+        column,
+        iterable,
+        retcols,
+        keyvalues={},
+        desc="simple_select_many_batch",
+        batch_size=100,
+    ):
+        """Executes a SELECT query on the named table, which may return zero or
+        more rows, returning the result as a list of dicts.
+
+        Filters rows by whether the value of `column` is in `iterable`.
+
+        Args:
+            table : string giving the table name
+            column : column name to test for inclusion against `iterable`
+            iterable : list
+            keyvalues : dict of column names and values to select the rows with
+            retcols : list of strings giving the names of the columns to return
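+
+        Example (illustrative table and column names):
+
+            rows = yield db.simple_select_many_batch(
+                table="users",
+                column="user_id",
+                iterable=user_ids,
+                retcols=("user_id", "name"),
+            )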
+        """
+        results = []
+
+        if not iterable:
+            return results
+
+        # iterables cannot be sliced, so convert it to a list first
+        it_list = list(iterable)
+
+        chunks = [
+            it_list[i : i + batch_size] for i in range(0, len(it_list), batch_size)
+        ]
+        for chunk in chunks:
+            rows = yield self.runInteraction(
+                desc,
+                self.simple_select_many_txn,
+                table,
+                column,
+                chunk,
+                keyvalues,
+                retcols,
+            )
+
+            results.extend(rows)
+
+        return results
+
+    @classmethod
+    def simple_select_many_txn(cls, txn, table, column, iterable, keyvalues, retcols):
+        """Executes a SELECT query on the named table, which may return zero or
+        more rows, returning the result as a list of dicts.
+
+        Filters rows by whether the value of `column` is in `iterable`.
+
+        Args:
+            txn : Transaction object
+            table : string giving the table name
+            column : column name to test for inclusion against `iterable`
+            iterable : list
+            keyvalues : dict of column names and values to select the rows with
+            retcols : list of strings giving the names of the columns to return
+        """
+        if not iterable:
+            return []
+
+        clause, values = make_in_list_sql_clause(txn.database_engine, column, iterable)
+        clauses = [clause]
+
+        for key, value in iteritems(keyvalues):
+            clauses.append("%s = ?" % (key,))
+            values.append(value)
+
+        sql = "SELECT %s FROM %s WHERE %s" % (
+            ", ".join(retcols),
+            table,
+            " AND ".join(clauses),
+        )
+
+        txn.execute(sql, values)
+        return cls.cursor_to_dict(txn)
+
+    def simple_update(self, table, keyvalues, updatevalues, desc):
+        return self.runInteraction(
+            desc, self.simple_update_txn, table, keyvalues, updatevalues
+        )
+
+    @staticmethod
+    def simple_update_txn(txn, table, keyvalues, updatevalues):
+        if keyvalues:
+            where = "WHERE %s" % " AND ".join("%s = ?" % k for k in iterkeys(keyvalues))
+        else:
+            where = ""
+
+        update_sql = "UPDATE %s SET %s %s" % (
+            table,
+            ", ".join("%s = ?" % (k,) for k in updatevalues),
+            where,
+        )
+
+        txn.execute(update_sql, list(updatevalues.values()) + list(keyvalues.values()))
+
+        return txn.rowcount
+
+    def simple_update_one(
+        self, table, keyvalues, updatevalues, desc="simple_update_one"
+    ):
+        """Executes an UPDATE query on the named table, setting new values for
+        columns in a row matching the key values.
+
+        Args:
+            table : string giving the table name
+            keyvalues : dict of column names and values to select the row with
+            updatevalues : dict giving column names and values to update
+
+        This can be used to implement compare-and-set by also putting the
+        column to be updated in the 'keyvalues' dict: the update only happens
+        if the existing row matches all of the key values.
+        """
+        return self.runInteraction(
+            desc, self.simple_update_one_txn, table, keyvalues, updatevalues
+        )
+
+    @classmethod
+    def simple_update_one_txn(cls, txn, table, keyvalues, updatevalues):
+        rowcount = cls.simple_update_txn(txn, table, keyvalues, updatevalues)
+
+        if rowcount == 0:
+            raise StoreError(404, "No row found (%s)" % (table,))
+        if rowcount > 1:
+            raise StoreError(500, "More than one row matched (%s)" % (table,))
+
+    @staticmethod
+    def simple_select_one_txn(txn, table, keyvalues, retcols, allow_none=False):
+        select_sql = "SELECT %s FROM %s WHERE %s" % (
+            ", ".join(retcols),
+            table,
+            " AND ".join("%s = ?" % (k,) for k in keyvalues),
+        )
+
+        txn.execute(select_sql, list(keyvalues.values()))
+        row = txn.fetchone()
+
+        if not row:
+            if allow_none:
+                return None
+            raise StoreError(404, "No row found (%s)" % (table,))
+        if txn.rowcount > 1:
+            raise StoreError(500, "More than one row matched (%s)" % (table,))
+
+        return dict(zip(retcols, row))
+
+    def simple_delete_one(self, table, keyvalues, desc="simple_delete_one"):
+        """Executes a DELETE query on the named table, expecting to delete a
+        single row.
+
+        Args:
+            table : string giving the table name
+            keyvalues : dict of column names and values to select the row with
+        """
+        return self.runInteraction(desc, self.simple_delete_one_txn, table, keyvalues)
+
+    @staticmethod
+    def simple_delete_one_txn(txn, table, keyvalues):
+        """Executes a DELETE query on the named table, expecting to delete a
+        single row.
+
+        Args:
+            table : string giving the table name
+            keyvalues : dict of column names and values to select the row with
+        """
+        sql = "DELETE FROM %s WHERE %s" % (
+            table,
+            " AND ".join("%s = ?" % (k,) for k in keyvalues),
+        )
+
+        txn.execute(sql, list(keyvalues.values()))
+        if txn.rowcount == 0:
+            raise StoreError(404, "No row found (%s)" % (table,))
+        if txn.rowcount > 1:
+            raise StoreError(500, "More than one row matched (%s)" % (table,))
+
+    def simple_delete(self, table, keyvalues, desc):
+        return self.runInteraction(desc, self.simple_delete_txn, table, keyvalues)
+
+    @staticmethod
+    def simple_delete_txn(txn, table, keyvalues):
+        sql = "DELETE FROM %s WHERE %s" % (
+            table,
+            " AND ".join("%s = ?" % (k,) for k in keyvalues),
+        )
+
+        txn.execute(sql, list(keyvalues.values()))
+        return txn.rowcount
+
+    def simple_delete_many(self, table, column, iterable, keyvalues, desc):
+        return self.runInteraction(
+            desc, self.simple_delete_many_txn, table, column, iterable, keyvalues
+        )
+
+    @staticmethod
+    def simple_delete_many_txn(txn, table, column, iterable, keyvalues):
+        """Executes a DELETE query on the named table.
+
+        Filters rows by whether the value of `column` is in `iterable`.
+
+        Args:
+            txn : Transaction object
+            table : string giving the table name
+            column : column name to test for inclusion against `iterable`
+            iterable : list
+            keyvalues : dict of column names and values to select the rows with
+
+        Returns:
+            int: Number of rows deleted
+        """
+        if not iterable:
+            return 0
+
+        sql = "DELETE FROM %s" % table
+
+        clause, values = make_in_list_sql_clause(txn.database_engine, column, iterable)
+        clauses = [clause]
+
+        for key, value in iteritems(keyvalues):
+            clauses.append("%s = ?" % (key,))
+            values.append(value)
+
+        if clauses:
+            sql = "%s WHERE %s" % (sql, " AND ".join(clauses))
+        txn.execute(sql, values)
+
+        return txn.rowcount
+
+    def get_cache_dict(
+        self, db_conn, table, entity_column, stream_column, max_value, limit=100000
+    ):
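+        """Fetches, for each entity in the given table, the maximum stream
+        position among "recent" rows (those whose stream position is within
+        `limit` of `max_value`), for priming a StreamChangeCache.
+
+        Returns:
+            A tuple of the mapping and the minimum value in it (or `max_value`
+            if the mapping is empty).
+        """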
+        # It doesn't really matter how many rows we fetch: the
+        # StreamChangeCache will do the right thing to ensure it respects its
+        # maximum size.
+        sql = (
+            "SELECT %(entity)s, MAX(%(stream)s) FROM %(table)s"
+            " WHERE %(stream)s > ? - %(limit)s"
+            " GROUP BY %(entity)s"
+        ) % {
+            "table": table,
+            "entity": entity_column,
+            "stream": stream_column,
+            "limit": limit,
+        }
+
+        sql = self.engine.convert_param_style(sql)
+
+        txn = db_conn.cursor()
+        txn.execute(sql, (int(max_value),))
+
+        cache = {row[0]: int(row[1]) for row in txn}
+
+        txn.close()
+
+        if cache:
+            min_val = min(itervalues(cache))
+        else:
+            min_val = max_value
+
+        return cache, min_val
+
+    def simple_select_list_paginate(
+        self,
+        table,
+        orderby,
+        start,
+        limit,
+        retcols,
+        filters=None,
+        keyvalues=None,
+        order_direction="ASC",
+        desc="simple_select_list_paginate",
+    ):
+        """
+        Executes a SELECT query on the named table, returning up to `limit`
+        rows starting at row number `start`, as a list of dicts.
+
+        Args:
+            table (str): the table name
+            filters (dict[str, T] | None):
+                column names and values to filter the rows with, or None to not
+                apply a WHERE ... LIKE ? clause.
+            keyvalues (dict[str, T] | None):
+                column names and values to select the rows with, or None to not
+                apply a WHERE clause.
+            orderby (str): Column to order the results by.
+            start (int): Index to begin the query at.
+            limit (int): Number of results to return.
+            retcols (iterable[str]): the names of the columns to return
+            order_direction (str): Whether the results should be ordered "ASC" or "DESC".
+        Returns:
+            defer.Deferred: resolves to list[dict[str, Any]]
+        """
+        return self.runInteraction(
+            desc,
+            self.simple_select_list_paginate_txn,
+            table,
+            orderby,
+            start,
+            limit,
+            retcols,
+            filters=filters,
+            keyvalues=keyvalues,
+            order_direction=order_direction,
+        )
+
+    @classmethod
+    def simple_select_list_paginate_txn(
+        cls,
+        txn,
+        table,
+        orderby,
+        start,
+        limit,
+        retcols,
+        filters=None,
+        keyvalues=None,
+        order_direction="ASC",
+    ):
+        """
+        Executes a SELECT query on the named table, returning up to `limit`
+        rows starting at row number `start`, as a list of dicts.
+
+        Use `filters` to search attributes using SQL wildcards and/or `keyvalues` to
+        select attributes with exact matches. All constraints are joined together
+        using 'AND'.
+
+        Args:
+            txn : Transaction object
+            table (str): the table name
+            orderby (str): Column to order the results by.
+            start (int): Index to begin the query at.
+            limit (int): Number of results to return.
+            retcols (iterable[str]): the names of the columns to return
+            filters (dict[str, T] | None):
+                column names and values to filter the rows with, or None to not
+                apply a WHERE ... LIKE ? clause.
+            keyvalues (dict[str, T] | None):
+                column names and values to select the rows with, or None to not
+                apply a WHERE clause.
+            order_direction (str): Whether the results should be ordered "ASC" or "DESC".
+        Returns:
+            list[dict[str, Any]]
+        """
+        if order_direction not in ["ASC", "DESC"]:
+            raise ValueError("order_direction must be one of 'ASC' or 'DESC'.")
+
+        where_clause = "WHERE " if filters or keyvalues else ""
+        arg_list = []
+        if filters:
+            where_clause += " AND ".join("%s LIKE ?" % (k,) for k in filters)
+            arg_list += list(filters.values())
+        where_clause += " AND " if filters and keyvalues else ""
+        if keyvalues:
+            where_clause += " AND ".join("%s = ?" % (k,) for k in keyvalues)
+            arg_list += list(keyvalues.values())
+
+        sql = "SELECT %s FROM %s %s ORDER BY %s %s LIMIT ? OFFSET ?" % (
+            ", ".join(retcols),
+            table,
+            where_clause,
+            orderby,
+            order_direction,
+        )
+        txn.execute(sql, arg_list + [limit, start])
+
+        return cls.cursor_to_dict(txn)
+
+    def simple_search_list(self, table, term, col, retcols, desc="simple_search_list"):
+        """Executes a SELECT query on the named table, which may return zero or
+        more rows, returning the result as a list of dicts.
+
+        Args:
+            table (str): the table name
+            term (str | None):
+                term for searching the table matched to a column.
+            col (str): column to query term should be matched to
+            retcols (iterable[str]): the names of the columns to return
+        Returns:
+            defer.Deferred: resolves to list[dict[str, Any]], or to 0 if no
+                search term is given
+        """
+
+        return self.runInteraction(
+            desc, self.simple_search_list_txn, table, term, col, retcols
+        )
+
+    @classmethod
+    def simple_search_list_txn(cls, txn, table, term, col, retcols):
+        """Executes a SELECT query on the named table, which may return zero or
+        more rows, returning the result as a list of dicts.
+
+        Args:
+            txn : Transaction object
+            table (str): the table name
+            term (str | None):
+                term for searching the table matched to a column.
+            col (str): column to query term should be matched to
+            retcols (iterable[str]): the names of the columns to return
+        Returns:
+            list[dict[str, Any]], or 0 if no search term is given
+        """
+        if term:
+            sql = "SELECT %s FROM %s WHERE %s LIKE ?" % (", ".join(retcols), table, col)
+            termvalues = ["%" + term + "%"]
+            txn.execute(sql, termvalues)
+        else:
+            return 0
+
+        return cls.cursor_to_dict(txn)
+
+
+def make_in_list_sql_clause(
+    database_engine, column: str, iterable: Iterable
+) -> Tuple[str, Iterable]:
+    """Returns an SQL clause that checks the given column is in the iterable.
+
+    On SQLite this expands to `column IN (?, ?, ...)`, whereas on Postgres
+    it expands to `column = ANY(?)`. While both DBs support the `IN` form,
+    using the `ANY` form on Postgres means that it treats queries with
+    different-length iterables as the same query, which helps the query stats.
+
+    Args:
+        database_engine
+        column: Name of the column
+        iterable: The values to check the column against.
+
+    Returns:
+        A tuple of SQL query and the args
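+
+    For example (illustrative values), on SQLite
+    make_in_list_sql_clause(engine, "room_id", ["!a", "!b"]) returns
+    ("room_id IN (?,?)", ["!a", "!b"]), while on Postgres it returns
+    ("room_id = ANY(?)", [["!a", "!b"]]).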
+    """
+
+    if database_engine.supports_using_any_list:
+        # This should hopefully be faster, but also makes postgres query
+        # stats easier to understand.
+        return "%s = ANY(?)" % (column,), [list(iterable)]
+    else:
+        return "%s IN (%s)" % (column, ",".join("?" for _ in iterable)), list(iterable)
diff --git a/synapse/util/caches/snapshot_cache.py b/synapse/util/caches/snapshot_cache.py
deleted file mode 100644
index 8318db8d2c..0000000000
--- a/synapse/util/caches/snapshot_cache.py
+++ /dev/null
@@ -1,94 +0,0 @@
-# -*- coding: utf-8 -*-
-# Copyright 2015, 2016 OpenMarket Ltd
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-#     http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-
-from synapse.util.async_helpers import ObservableDeferred
-
-
-class SnapshotCache(object):
-    """Cache for snapshots like the response of /initialSync.
-    The response of initialSync only has to be a recent snapshot of the
-    server state. It shouldn't matter to clients if it is a few minutes out
-    of date.
-
-    This caches a deferred response. Until the deferred completes it will be
-    returned from the cache. This means that if the client retries the request
-    while the response is still being computed, that original response will be
-    used rather than trying to compute a new response.
-
-    Once the deferred completes it will removed from the cache after 5 minutes.
-    We delay removing it from the cache because a client retrying its request
-    could race with us finishing computing the response.
-
-    Rather than tracking precisely how long something has been in the cache we
-    keep two generations of completed responses. Every 5 minutes discard the
-    old generation, move the new generation to the old generation, and set the
-    new generation to be empty. This means that a result will be in the cache
-    somewhere between 5 and 10 minutes.
-    """
-
-    DURATION_MS = 5 * 60 * 1000  # Cache results for 5 minutes.
-
-    def __init__(self):
-        self.pending_result_cache = {}  # Request that haven't finished yet.
-        self.prev_result_cache = {}  # The older requests that have finished.
-        self.next_result_cache = {}  # The newer requests that have finished.
-        self.time_last_rotated_ms = 0
-
-    def rotate(self, time_now_ms):
-        # Rotate once if the cache duration has passed since the last rotation.
-        if time_now_ms - self.time_last_rotated_ms >= self.DURATION_MS:
-            self.prev_result_cache = self.next_result_cache
-            self.next_result_cache = {}
-            self.time_last_rotated_ms += self.DURATION_MS
-
-        # Rotate again if the cache duration has passed twice since the last
-        # rotation.
-        if time_now_ms - self.time_last_rotated_ms >= self.DURATION_MS:
-            self.prev_result_cache = self.next_result_cache
-            self.next_result_cache = {}
-            self.time_last_rotated_ms = time_now_ms
-
-    def get(self, time_now_ms, key):
-        self.rotate(time_now_ms)
-        # This cache is intended to deduplicate requests, so we expect it to be
-        # missed most of the time. So we just lookup the key in all of the
-        # dictionaries rather than trying to short circuit the lookup if the
-        # key is found.
-        result = self.prev_result_cache.get(key)
-        result = self.next_result_cache.get(key, result)
-        result = self.pending_result_cache.get(key, result)
-        if result is not None:
-            return result.observe()
-        else:
-            return None
-
-    def set(self, time_now_ms, key, deferred):
-        self.rotate(time_now_ms)
-
-        result = ObservableDeferred(deferred)
-
-        self.pending_result_cache[key] = result
-
-        def shuffle_along(r):
-            # When the deferred completes we shuffle it along to the first
-            # generation of the result cache. So that it will eventually
-            # expire from the rotation of that cache.
-            self.next_result_cache[key] = result
-            self.pending_result_cache.pop(key, None)
-            return r
-
-        result.addBoth(shuffle_along)
-
-        return result.observe()
diff --git a/synapse/util/metrics.py b/synapse/util/metrics.py
index 3286804322..7b18455469 100644
--- a/synapse/util/metrics.py
+++ b/synapse/util/metrics.py
@@ -13,6 +13,7 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
+import inspect
 import logging
 from functools import wraps
 
@@ -64,12 +65,22 @@ def measure_func(name=None):
     def wrapper(func):
         block_name = func.__name__ if name is None else name
 
-        @wraps(func)
-        @defer.inlineCallbacks
-        def measured_func(self, *args, **kwargs):
-            with Measure(self.clock, block_name):
-                r = yield func(self, *args, **kwargs)
-            return r
+        if inspect.iscoroutinefunction(func):
+
+            @wraps(func)
+            async def measured_func(self, *args, **kwargs):
+                with Measure(self.clock, block_name):
+                    r = await func(self, *args, **kwargs)
+                return r
+
+        else:
+
+            @wraps(func)
+            @defer.inlineCallbacks
+            def measured_func(self, *args, **kwargs):
+                with Measure(self.clock, block_name):
+                    r = yield func(self, *args, **kwargs)
+                return r
 
         return measured_func
 
@@ -80,72 +91,48 @@ class Measure(object):
     __slots__ = [
         "clock",
         "name",
-        "start_context",
+        "_logging_context",
         "start",
-        "created_context",
-        "start_usage",
     ]
 
     def __init__(self, clock, name):
         self.clock = clock
         self.name = name
-        self.start_context = None
+        self._logging_context = None
         self.start = None
-        self.created_context = False
 
     def __enter__(self):
-        self.start = self.clock.time()
-        self.start_context = LoggingContext.current_context()
-        if not self.start_context:
-            self.start_context = LoggingContext("Measure")
-            self.start_context.__enter__()
-            self.created_context = True
-
-        self.start_usage = self.start_context.get_resource_usage()
+        if self._logging_context:
+            raise RuntimeError("Measure() objects cannot be re-used")
 
+        self.start = self.clock.time()
+        parent_context = LoggingContext.current_context()
+        self._logging_context = LoggingContext(
+            "Measure[%s]" % (self.name,), parent_context
+        )
+        self._logging_context.__enter__()
         in_flight.register((self.name,), self._update_in_flight)
 
     def __exit__(self, exc_type, exc_val, exc_tb):
-        if isinstance(exc_type, Exception) or not self.start_context:
-            return
-
-        in_flight.unregister((self.name,), self._update_in_flight)
+        if not self._logging_context:
+            raise RuntimeError("Measure() block exited without being entered")
 
         duration = self.clock.time() - self.start
+        usage = self._logging_context.get_resource_usage()
 
-        block_counter.labels(self.name).inc()
-        block_timer.labels(self.name).inc(duration)
-
-        context = LoggingContext.current_context()
-
-        if context != self.start_context:
-            logger.warning(
-                "Context has unexpectedly changed from '%s' to '%s'. (%r)",
-                self.start_context,
-                context,
-                self.name,
-            )
-            return
-
-        if not context:
-            logger.warning("Expected context. (%r)", self.name)
-            return
+        in_flight.unregister((self.name,), self._update_in_flight)
+        self._logging_context.__exit__(exc_type, exc_val, exc_tb)
 
-        current = context.get_resource_usage()
-        usage = current - self.start_usage
         try:
+            block_counter.labels(self.name).inc()
+            block_timer.labels(self.name).inc(duration)
             block_ru_utime.labels(self.name).inc(usage.ru_utime)
             block_ru_stime.labels(self.name).inc(usage.ru_stime)
             block_db_txn_count.labels(self.name).inc(usage.db_txn_count)
             block_db_txn_duration.labels(self.name).inc(usage.db_txn_duration_sec)
             block_db_sched_duration.labels(self.name).inc(usage.db_sched_duration_sec)
         except ValueError:
-            logger.warning(
-                "Failed to save metrics! OLD: %r, NEW: %r", self.start_usage, current
-            )
-
-        if self.created_context:
-            self.start_context.__exit__(exc_type, exc_val, exc_tb)
+            logger.warning("Failed to save metrics! Usage: %s", usage)
 
     def _update_in_flight(self, metrics):
         """Gets called when processing in flight metrics
diff --git a/synmark/__init__.py b/synmark/__init__.py
index 570eb818d9..afe4fad8cb 100644
--- a/synmark/__init__.py
+++ b/synmark/__init__.py
@@ -47,9 +47,9 @@ async def make_homeserver(reactor, config=None):
     stor = hs.get_datastore()
 
     # Run the database background updates.
-    if hasattr(stor, "do_next_background_update"):
-        while not await stor.has_completed_background_updates():
-            await stor.do_next_background_update(1)
+    if hasattr(stor.db.updates, "do_next_background_update"):
+        while not await stor.db.updates.has_completed_background_updates():
+            await stor.db.updates.do_next_background_update(1)
 
     def cleanup():
         for i in cleanup_tasks:
diff --git a/sytest-blacklist b/sytest-blacklist
index 411cce0692..79b2d4402a 100644
--- a/sytest-blacklist
+++ b/sytest-blacklist
@@ -33,3 +33,6 @@ New federated private chats get full presence information (SYN-115)
 # Blacklisted due to https://github.com/matrix-org/matrix-doc/pull/2314 removing
 # this requirement from the spec
 Inbound federation of state requires event_id as a mandatory paramater
+
+# Blacklisted until https://github.com/matrix-org/synapse/pull/6486 lands
+Can upload self-signing keys
diff --git a/tests/handlers/test_e2e_keys.py b/tests/handlers/test_e2e_keys.py
index 854eb6c024..fdfa2cbbc4 100644
--- a/tests/handlers/test_e2e_keys.py
+++ b/tests/handlers/test_e2e_keys.py
@@ -183,6 +183,10 @@ class E2eKeysHandlerTestCase(unittest.TestCase):
         )
         self.assertDictEqual(devices["master_keys"], {local_user: keys2["master_key"]})
 
+    test_replace_master_key.skip = (
+        "Disabled waiting on #https://github.com/matrix-org/synapse/pull/6486"
+    )
+
     @defer.inlineCallbacks
     def test_reupload_signatures(self):
         """re-uploading a signature should not fail"""
@@ -503,3 +507,7 @@ class E2eKeysHandlerTestCase(unittest.TestCase):
             ],
             other_master_key["signatures"][local_user]["ed25519:" + usersigning_pubkey],
         )
+
+    test_upload_signatures.skip = (
+        "Disabled waiting on #https://github.com/matrix-org/synapse/pull/6486"
+    )
diff --git a/tests/handlers/test_stats.py b/tests/handlers/test_stats.py
index 380fd0d107..d9d312f0fb 100644
--- a/tests/handlers/test_stats.py
+++ b/tests/handlers/test_stats.py
@@ -42,16 +42,16 @@ class StatsRoomTests(unittest.HomeserverTestCase):
         Add the background updates we need to run.
         """
         # Ugh, have to reset this flag
-        self.store._all_done = False
+        self.store.db.updates._all_done = False
 
         self.get_success(
-            self.store.simple_insert(
+            self.store.db.simple_insert(
                 "background_updates",
                 {"update_name": "populate_stats_prepare", "progress_json": "{}"},
             )
         )
         self.get_success(
-            self.store.simple_insert(
+            self.store.db.simple_insert(
                 "background_updates",
                 {
                     "update_name": "populate_stats_process_rooms",
@@ -61,7 +61,7 @@ class StatsRoomTests(unittest.HomeserverTestCase):
             )
         )
         self.get_success(
-            self.store.simple_insert(
+            self.store.db.simple_insert(
                 "background_updates",
                 {
                     "update_name": "populate_stats_process_users",
@@ -71,7 +71,7 @@ class StatsRoomTests(unittest.HomeserverTestCase):
             )
         )
         self.get_success(
-            self.store.simple_insert(
+            self.store.db.simple_insert(
                 "background_updates",
                 {
                     "update_name": "populate_stats_cleanup",
@@ -82,7 +82,7 @@ class StatsRoomTests(unittest.HomeserverTestCase):
         )
 
     def get_all_room_state(self):
-        return self.store.simple_select_list(
+        return self.store.db.simple_select_list(
             "room_stats_state", None, retcols=("name", "topic", "canonical_alias")
         )
 
@@ -96,7 +96,7 @@ class StatsRoomTests(unittest.HomeserverTestCase):
         end_ts = self.store.quantise_stats_time(self.reactor.seconds() * 1000)
 
         return self.get_success(
-            self.store.simple_select_one(
+            self.store.db.simple_select_one(
                 table + "_historical",
                 {id_col: stat_id, end_ts: end_ts},
                 cols,
@@ -108,8 +108,12 @@ class StatsRoomTests(unittest.HomeserverTestCase):
         # Do the initial population of the stats via the background update
         self._add_background_updates()
 
-        while not self.get_success(self.store.has_completed_background_updates()):
-            self.get_success(self.store.do_next_background_update(100), by=0.1)
+        while not self.get_success(
+            self.store.db.updates.has_completed_background_updates()
+        ):
+            self.get_success(
+                self.store.db.updates.do_next_background_update(100), by=0.1
+            )
 
     def test_initial_room(self):
         """
@@ -141,8 +145,12 @@ class StatsRoomTests(unittest.HomeserverTestCase):
         # Do the initial population of the user directory via the background update
         self._add_background_updates()
 
-        while not self.get_success(self.store.has_completed_background_updates()):
-            self.get_success(self.store.do_next_background_update(100), by=0.1)
+        while not self.get_success(
+            self.store.db.updates.has_completed_background_updates()
+        ):
+            self.get_success(
+                self.store.db.updates.do_next_background_update(100), by=0.1
+            )
 
         r = self.get_success(self.get_all_room_state())
 
@@ -178,9 +186,9 @@ class StatsRoomTests(unittest.HomeserverTestCase):
         # the position that the deltas should begin at, once they take over.
         self.hs.config.stats_enabled = True
         self.handler.stats_enabled = True
-        self.store._all_done = False
+        self.store.db.updates._all_done = False
         self.get_success(
-            self.store.simple_update_one(
+            self.store.db.simple_update_one(
                 table="stats_incremental_position",
                 keyvalues={},
                 updatevalues={"stream_id": 0},
@@ -188,14 +196,18 @@ class StatsRoomTests(unittest.HomeserverTestCase):
         )
 
         self.get_success(
-            self.store.simple_insert(
+            self.store.db.simple_insert(
                 "background_updates",
                 {"update_name": "populate_stats_prepare", "progress_json": "{}"},
             )
         )
 
-        while not self.get_success(self.store.has_completed_background_updates()):
-            self.get_success(self.store.do_next_background_update(100), by=0.1)
+        while not self.get_success(
+            self.store.db.updates.has_completed_background_updates()
+        ):
+            self.get_success(
+                self.store.db.updates.do_next_background_update(100), by=0.1
+            )
 
         # Now, before the table is actually ingested, add some more events.
         self.helper.invite(room=room_1, src=u1, targ=u2, tok=u1_token)
@@ -205,13 +217,13 @@ class StatsRoomTests(unittest.HomeserverTestCase):
 
         # Now do the initial ingestion.
         self.get_success(
-            self.store.simple_insert(
+            self.store.db.simple_insert(
                 "background_updates",
                 {"update_name": "populate_stats_process_rooms", "progress_json": "{}"},
             )
         )
         self.get_success(
-            self.store.simple_insert(
+            self.store.db.simple_insert(
                 "background_updates",
                 {
                     "update_name": "populate_stats_cleanup",
@@ -221,9 +233,13 @@ class StatsRoomTests(unittest.HomeserverTestCase):
             )
         )
 
-        self.store._all_done = False
-        while not self.get_success(self.store.has_completed_background_updates()):
-            self.get_success(self.store.do_next_background_update(100), by=0.1)
+        self.store.db.updates._all_done = False
+        while not self.get_success(
+            self.store.db.updates.has_completed_background_updates()
+        ):
+            self.get_success(
+                self.store.db.updates.do_next_background_update(100), by=0.1
+            )
 
         self.reactor.advance(86401)
 
@@ -653,15 +669,15 @@ class StatsRoomTests(unittest.HomeserverTestCase):
 
         # preparation stage of the initial background update
         # Ugh, have to reset this flag
-        self.store._all_done = False
+        self.store.db.updates._all_done = False
 
         self.get_success(
-            self.store.simple_delete(
+            self.store.db.simple_delete(
                 "room_stats_current", {"1": 1}, "test_delete_stats"
             )
         )
         self.get_success(
-            self.store.simple_delete(
+            self.store.db.simple_delete(
                 "user_stats_current", {"1": 1}, "test_delete_stats"
             )
         )
@@ -673,9 +689,9 @@ class StatsRoomTests(unittest.HomeserverTestCase):
 
         # now do the background updates
 
-        self.store._all_done = False
+        self.store.db.updates._all_done = False
         self.get_success(
-            self.store.simple_insert(
+            self.store.db.simple_insert(
                 "background_updates",
                 {
                     "update_name": "populate_stats_process_rooms",
@@ -685,7 +701,7 @@ class StatsRoomTests(unittest.HomeserverTestCase):
             )
         )
         self.get_success(
-            self.store.simple_insert(
+            self.store.db.simple_insert(
                 "background_updates",
                 {
                     "update_name": "populate_stats_process_users",
@@ -695,7 +711,7 @@ class StatsRoomTests(unittest.HomeserverTestCase):
             )
         )
         self.get_success(
-            self.store.simple_insert(
+            self.store.db.simple_insert(
                 "background_updates",
                 {
                     "update_name": "populate_stats_cleanup",
@@ -705,8 +721,12 @@ class StatsRoomTests(unittest.HomeserverTestCase):
             )
         )
 
-        while not self.get_success(self.store.has_completed_background_updates()):
-            self.get_success(self.store.do_next_background_update(100), by=0.1)
+        while not self.get_success(
+            self.store.db.updates.has_completed_background_updates()
+        ):
+            self.get_success(
+                self.store.db.updates.do_next_background_update(100), by=0.1
+            )
 
         r1stats_complete = self._get_current_stats("room", r1)
         u1stats_complete = self._get_current_stats("user", u1)
diff --git a/tests/handlers/test_sync.py b/tests/handlers/test_sync.py
index 31f54bbd7d..758ee071a5 100644
--- a/tests/handlers/test_sync.py
+++ b/tests/handlers/test_sync.py
@@ -12,54 +12,53 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
-from twisted.internet import defer
 
 from synapse.api.errors import Codes, ResourceLimitError
 from synapse.api.filtering import DEFAULT_FILTER_COLLECTION
-from synapse.handlers.sync import SyncConfig, SyncHandler
+from synapse.handlers.sync import SyncConfig
 from synapse.types import UserID
 
 import tests.unittest
 import tests.utils
-from tests.utils import setup_test_homeserver
 
 
-class SyncTestCase(tests.unittest.TestCase):
+class SyncTestCase(tests.unittest.HomeserverTestCase):
     """ Tests Sync Handler. """
 
-    @defer.inlineCallbacks
-    def setUp(self):
-        self.hs = yield setup_test_homeserver(self.addCleanup)
-        self.sync_handler = SyncHandler(self.hs)
+    def prepare(self, reactor, clock, hs):
+        self.hs = hs
+        self.sync_handler = self.hs.get_sync_handler()
         self.store = self.hs.get_datastore()
 
-    @defer.inlineCallbacks
     def test_wait_for_sync_for_user_auth_blocking(self):
 
         user_id1 = "@user1:server"
         user_id2 = "@user2:server"
         sync_config = self._generate_sync_config(user_id1)
 
+        self.reactor.advance(100)  # So the clock is not at time 0
         self.hs.config.limit_usage_by_mau = True
         self.hs.config.max_mau_value = 1
 
         # Check that the happy case does not throw errors
-        yield self.store.upsert_monthly_active_user(user_id1)
-        yield self.sync_handler.wait_for_sync_for_user(sync_config)
+        self.get_success(self.store.upsert_monthly_active_user(user_id1))
+        self.get_success(self.sync_handler.wait_for_sync_for_user(sync_config))
 
         # Test that global lock works
         self.hs.config.hs_disabled = True
-        with self.assertRaises(ResourceLimitError) as e:
-            yield self.sync_handler.wait_for_sync_for_user(sync_config)
-        self.assertEquals(e.exception.errcode, Codes.RESOURCE_LIMIT_EXCEEDED)
+        e = self.get_failure(
+            self.sync_handler.wait_for_sync_for_user(sync_config), ResourceLimitError
+        )
+        self.assertEquals(e.value.errcode, Codes.RESOURCE_LIMIT_EXCEEDED)
 
         self.hs.config.hs_disabled = False
 
         sync_config = self._generate_sync_config(user_id2)
 
-        with self.assertRaises(ResourceLimitError) as e:
-            yield self.sync_handler.wait_for_sync_for_user(sync_config)
-        self.assertEquals(e.exception.errcode, Codes.RESOURCE_LIMIT_EXCEEDED)
+        e = self.get_failure(
+            self.sync_handler.wait_for_sync_for_user(sync_config), ResourceLimitError
+        )
+        self.assertEquals(e.value.errcode, Codes.RESOURCE_LIMIT_EXCEEDED)
 
     def _generate_sync_config(self, user_id):
         return SyncConfig(
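
The assertion rewrite above (`e.exception.errcode` becoming `e.value.errcode`) follows from the harness change: `assertRaises` hands back the exception itself, while `HomeserverTestCase.get_failure` returns a `twisted.python.failure.Failure` wrapping it, and the exception lives on the Failure's `.value`. A standalone illustration of that difference (the `ResourceLimitError` here is a stand-in, not Synapse's):

    from twisted.internet import defer
    from twisted.python.failure import Failure

    class ResourceLimitError(Exception):
        def __init__(self, errcode):
            self.errcode = errcode

    failures = []
    d = defer.fail(ResourceLimitError("M_RESOURCE_LIMIT_EXCEEDED"))
    d.addErrback(failures.append)  # errbacks get a Failure, not the exception

    f = failures[0]
    assert isinstance(f, Failure)
    assert f.value.errcode == "M_RESOURCE_LIMIT_EXCEEDED"
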
diff --git a/tests/handlers/test_typing.py b/tests/handlers/test_typing.py
index f6d8660285..92b8726093 100644
--- a/tests/handlers/test_typing.py
+++ b/tests/handlers/test_typing.py
@@ -163,7 +163,9 @@ class TypingNotificationsTestCase(unittest.HomeserverTestCase):
         self.on_new_event.assert_has_calls([call("typing_key", 1, rooms=[ROOM_ID])])
 
         self.assertEquals(self.event_source.get_current_key(), 1)
-        events = self.event_source.get_new_events(room_ids=[ROOM_ID], from_key=0)
+        events = self.get_success(
+            self.event_source.get_new_events(room_ids=[ROOM_ID], from_key=0)
+        )
         self.assertEquals(
             events[0],
             [
@@ -227,7 +229,9 @@ class TypingNotificationsTestCase(unittest.HomeserverTestCase):
         self.on_new_event.assert_has_calls([call("typing_key", 1, rooms=[ROOM_ID])])
 
         self.assertEquals(self.event_source.get_current_key(), 1)
-        events = self.event_source.get_new_events(room_ids=[ROOM_ID], from_key=0)
+        events = self.get_success(
+            self.event_source.get_new_events(room_ids=[ROOM_ID], from_key=0)
+        )
         self.assertEquals(
             events[0],
             [
@@ -279,7 +283,9 @@ class TypingNotificationsTestCase(unittest.HomeserverTestCase):
         )
 
         self.assertEquals(self.event_source.get_current_key(), 1)
-        events = self.event_source.get_new_events(room_ids=[ROOM_ID], from_key=0)
+        events = self.get_success(
+            self.event_source.get_new_events(room_ids=[ROOM_ID], from_key=0)
+        )
         self.assertEquals(
             events[0],
             [{"type": "m.typing", "room_id": ROOM_ID, "content": {"user_ids": []}}],
@@ -300,7 +306,9 @@ class TypingNotificationsTestCase(unittest.HomeserverTestCase):
         self.on_new_event.reset_mock()
 
         self.assertEquals(self.event_source.get_current_key(), 1)
-        events = self.event_source.get_new_events(room_ids=[ROOM_ID], from_key=0)
+        events = self.get_success(
+            self.event_source.get_new_events(room_ids=[ROOM_ID], from_key=0)
+        )
         self.assertEquals(
             events[0],
             [
@@ -317,7 +325,9 @@ class TypingNotificationsTestCase(unittest.HomeserverTestCase):
         self.on_new_event.assert_has_calls([call("typing_key", 2, rooms=[ROOM_ID])])
 
         self.assertEquals(self.event_source.get_current_key(), 2)
-        events = self.event_source.get_new_events(room_ids=[ROOM_ID], from_key=1)
+        events = self.get_success(
+            self.event_source.get_new_events(room_ids=[ROOM_ID], from_key=1)
+        )
         self.assertEquals(
             events[0],
             [{"type": "m.typing", "room_id": ROOM_ID, "content": {"user_ids": []}}],
@@ -335,7 +345,9 @@ class TypingNotificationsTestCase(unittest.HomeserverTestCase):
         self.on_new_event.reset_mock()
 
         self.assertEquals(self.event_source.get_current_key(), 3)
-        events = self.event_source.get_new_events(room_ids=[ROOM_ID], from_key=0)
+        events = self.get_success(
+            self.event_source.get_new_events(room_ids=[ROOM_ID], from_key=0)
+        )
         self.assertEquals(
             events[0],
             [
diff --git a/tests/handlers/test_user_directory.py b/tests/handlers/test_user_directory.py
index d5b1c5b4ac..26071059d2 100644
--- a/tests/handlers/test_user_directory.py
+++ b/tests/handlers/test_user_directory.py
@@ -158,7 +158,7 @@ class UserDirectoryTestCase(unittest.HomeserverTestCase):
 
     def get_users_in_public_rooms(self):
         r = self.get_success(
-            self.store.simple_select_list(
+            self.store.db.simple_select_list(
                 "users_in_public_rooms", None, ("user_id", "room_id")
             )
         )
@@ -169,7 +169,7 @@ class UserDirectoryTestCase(unittest.HomeserverTestCase):
 
     def get_users_who_share_private_rooms(self):
         return self.get_success(
-            self.store.simple_select_list(
+            self.store.db.simple_select_list(
                 "users_who_share_private_rooms",
                 None,
                 ["user_id", "other_user_id", "room_id"],
@@ -181,10 +181,10 @@ class UserDirectoryTestCase(unittest.HomeserverTestCase):
         Add the background updates we need to run.
         """
         # Ugh, have to reset this flag
-        self.store._all_done = False
+        self.store.db.updates._all_done = False
 
         self.get_success(
-            self.store.simple_insert(
+            self.store.db.simple_insert(
                 "background_updates",
                 {
                     "update_name": "populate_user_directory_createtables",
@@ -193,7 +193,7 @@ class UserDirectoryTestCase(unittest.HomeserverTestCase):
             )
         )
         self.get_success(
-            self.store.simple_insert(
+            self.store.db.simple_insert(
                 "background_updates",
                 {
                     "update_name": "populate_user_directory_process_rooms",
@@ -203,7 +203,7 @@ class UserDirectoryTestCase(unittest.HomeserverTestCase):
             )
         )
         self.get_success(
-            self.store.simple_insert(
+            self.store.db.simple_insert(
                 "background_updates",
                 {
                     "update_name": "populate_user_directory_process_users",
@@ -213,7 +213,7 @@ class UserDirectoryTestCase(unittest.HomeserverTestCase):
             )
         )
         self.get_success(
-            self.store.simple_insert(
+            self.store.db.simple_insert(
                 "background_updates",
                 {
                     "update_name": "populate_user_directory_cleanup",
@@ -255,8 +255,12 @@ class UserDirectoryTestCase(unittest.HomeserverTestCase):
         # Do the initial population of the user directory via the background update
         self._add_background_updates()
 
-        while not self.get_success(self.store.has_completed_background_updates()):
-            self.get_success(self.store.do_next_background_update(100), by=0.1)
+        while not self.get_success(
+            self.store.db.updates.has_completed_background_updates()
+        ):
+            self.get_success(
+                self.store.db.updates.do_next_background_update(100), by=0.1
+            )
 
         shares_private = self.get_users_who_share_private_rooms()
         public_users = self.get_users_in_public_rooms()
@@ -290,8 +294,12 @@ class UserDirectoryTestCase(unittest.HomeserverTestCase):
         # Do the initial population of the user directory via the background update
         self._add_background_updates()
 
-        while not self.get_success(self.store.has_completed_background_updates()):
-            self.get_success(self.store.do_next_background_update(100), by=0.1)
+        while not self.get_success(
+            self.store.db.updates.has_completed_background_updates()
+        ):
+            self.get_success(
+                self.store.db.updates.do_next_background_update(100), by=0.1
+            )
 
         shares_private = self.get_users_who_share_private_rooms()
         public_users = self.get_users_in_public_rooms()
diff --git a/tests/replication/slave/storage/_base.py b/tests/replication/slave/storage/_base.py
index e7472e3a93..3dae83c543 100644
--- a/tests/replication/slave/storage/_base.py
+++ b/tests/replication/slave/storage/_base.py
@@ -20,6 +20,7 @@ from synapse.replication.tcp.client import (
     ReplicationClientHandler,
 )
 from synapse.replication.tcp.resource import ReplicationStreamProtocolFactory
+from synapse.storage.database import Database
 
 from tests import unittest
 from tests.server import FakeTransport
@@ -42,7 +43,9 @@ class BaseSlavedStoreTestCase(unittest.HomeserverTestCase):
 
         self.master_store = self.hs.get_datastore()
         self.storage = hs.get_storage()
-        self.slaved_store = self.STORE_TYPE(self.hs.get_db_conn(), self.hs)
+        self.slaved_store = self.STORE_TYPE(
+            Database(hs), self.hs.get_db_conn(), self.hs
+        )
         self.event_id = 0
 
         server_factory = ReplicationStreamProtocolFactory(self.hs)
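
The constructor change above recurs through the storage tests below: store classes now take the `Database` as their first argument, so direct instantiations become `STORE_TYPE(Database(hs), db_conn, hs)`. A toy version of the new signature (illustrative names only, not the real classes):

    class Database:
        def __init__(self, hs):
            self.hs = hs

    class SQLBaseStore:
        def __init__(self, database, db_conn, hs):
            self.db = database  # helpers like simple_insert now live here
            self.db_conn = db_conn
            self.hs = hs

    class SlavedStore(SQLBaseStore):
        def __init__(self, database, db_conn, hs):
            super().__init__(database, db_conn, hs)

    hs = object()
    store = SlavedStore(Database(hs), db_conn=None, hs=hs)
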
diff --git a/tests/rest/admin/test_admin.py b/tests/rest/admin/test_admin.py
index 124ce0768a..0ed2594381 100644
--- a/tests/rest/admin/test_admin.py
+++ b/tests/rest/admin/test_admin.py
@@ -632,7 +632,7 @@ class PurgeRoomTestCase(unittest.HomeserverTestCase):
             "state_groups_state",
         ):
             count = self.get_success(
-                self.store.simple_select_one_onecol(
+                self.store.db.simple_select_one_onecol(
                     table=table,
                     keyvalues={"room_id": room_id},
                     retcol="COUNT(*)",
diff --git a/tests/rest/client/v1/test_typing.py b/tests/rest/client/v1/test_typing.py
index 30fb77bac8..4bc3aaf02d 100644
--- a/tests/rest/client/v1/test_typing.py
+++ b/tests/rest/client/v1/test_typing.py
@@ -109,7 +109,9 @@ class RoomTypingTestCase(unittest.HomeserverTestCase):
         self.assertEquals(200, channel.code)
 
         self.assertEquals(self.event_source.get_current_key(), 1)
-        events = self.event_source.get_new_events(from_key=0, room_ids=[self.room_id])
+        events = self.get_success(
+            self.event_source.get_new_events(from_key=0, room_ids=[self.room_id])
+        )
         self.assertEquals(
             events[0],
             [
diff --git a/tests/storage/test__base.py b/tests/storage/test__base.py
index 7b7434a468..d491ea2924 100644
--- a/tests/storage/test__base.py
+++ b/tests/storage/test__base.py
@@ -323,7 +323,7 @@ class UpsertManyTests(unittest.HomeserverTestCase):
 
         self.table_name = "table_" + hs.get_secrets().token_hex(6)
         self.get_success(
-            self.storage.runInteraction(
+            self.storage.db.runInteraction(
                 "create",
                 lambda x, *a: x.execute(*a),
                 "CREATE TABLE %s (id INTEGER, username TEXT, value TEXT)"
@@ -331,7 +331,7 @@ class UpsertManyTests(unittest.HomeserverTestCase):
             )
         )
         self.get_success(
-            self.storage.runInteraction(
+            self.storage.db.runInteraction(
                 "index",
                 lambda x, *a: x.execute(*a),
                 "CREATE UNIQUE INDEX %sindex ON %s(id, username)"
@@ -354,9 +354,9 @@ class UpsertManyTests(unittest.HomeserverTestCase):
         value_values = [["hello"], ["there"]]
 
         self.get_success(
-            self.storage.runInteraction(
+            self.storage.db.runInteraction(
                 "test",
-                self.storage.simple_upsert_many_txn,
+                self.storage.db.simple_upsert_many_txn,
                 self.table_name,
                 key_names,
                 key_values,
@@ -367,7 +367,7 @@ class UpsertManyTests(unittest.HomeserverTestCase):
 
         # Check results are what we expect
         res = self.get_success(
-            self.storage.simple_select_list(
+            self.storage.db.simple_select_list(
                 self.table_name, None, ["id, username, value"]
             )
         )
@@ -381,9 +381,9 @@ class UpsertManyTests(unittest.HomeserverTestCase):
         value_values = [["bleb"]]
 
         self.get_success(
-            self.storage.runInteraction(
+            self.storage.db.runInteraction(
                 "test",
-                self.storage.simple_upsert_many_txn,
+                self.storage.db.simple_upsert_many_txn,
                 self.table_name,
                 key_names,
                 key_values,
@@ -394,7 +394,7 @@ class UpsertManyTests(unittest.HomeserverTestCase):
 
         # Check results are what we expect
         res = self.get_success(
-            self.storage.simple_select_list(
+            self.storage.db.simple_select_list(
                 self.table_name, None, ["id, username, value"]
             )
         )
diff --git a/tests/storage/test_appservice.py b/tests/storage/test_appservice.py
index dfeea24599..2e521e9ab7 100644
--- a/tests/storage/test_appservice.py
+++ b/tests/storage/test_appservice.py
@@ -28,6 +28,7 @@ from synapse.storage.data_stores.main.appservice import (
     ApplicationServiceStore,
     ApplicationServiceTransactionStore,
 )
+from synapse.storage.database import Database
 
 from tests import unittest
 from tests.utils import setup_test_homeserver
@@ -54,7 +55,8 @@ class ApplicationServiceStoreTestCase(unittest.TestCase):
         self._add_appservice("token2", "as2", "some_url", "some_hs_token", "bob")
         self._add_appservice("token3", "as3", "some_url", "some_hs_token", "bob")
         # must be done after inserts
-        self.store = ApplicationServiceStore(hs.get_db_conn(), hs)
+        database = Database(hs)
+        self.store = ApplicationServiceStore(database, hs.get_db_conn(), hs)
 
     def tearDown(self):
         # TODO: suboptimal that we need to create files for tests!
@@ -123,7 +125,8 @@ class ApplicationServiceTransactionStoreTestCase(unittest.TestCase):
 
         self.as_yaml_files = []
 
-        self.store = TestTransactionStore(hs.get_db_conn(), hs)
+        database = Database(hs)
+        self.store = TestTransactionStore(database, hs.get_db_conn(), hs)
 
     def _add_service(self, url, as_token, id):
         as_yaml = dict(
@@ -382,8 +385,8 @@ class ApplicationServiceTransactionStoreTestCase(unittest.TestCase):
 
 # required for ApplicationServiceTransactionStoreTestCase tests
 class TestTransactionStore(ApplicationServiceTransactionStore, ApplicationServiceStore):
-    def __init__(self, db_conn, hs):
-        super(TestTransactionStore, self).__init__(db_conn, hs)
+    def __init__(self, database: Database, db_conn, hs):
+        super(TestTransactionStore, self).__init__(database, db_conn, hs)
 
 
 class ApplicationServiceStoreConfigTestCase(unittest.TestCase):
@@ -416,7 +419,7 @@ class ApplicationServiceStoreConfigTestCase(unittest.TestCase):
         hs.config.event_cache_size = 1
         hs.config.password_providers = []
 
-        ApplicationServiceStore(hs.get_db_conn(), hs)
+        ApplicationServiceStore(Database(hs), hs.get_db_conn(), hs)
 
     @defer.inlineCallbacks
     def test_duplicate_ids(self):
@@ -432,7 +435,7 @@ class ApplicationServiceStoreConfigTestCase(unittest.TestCase):
         hs.config.password_providers = []
 
         with self.assertRaises(ConfigError) as cm:
-            ApplicationServiceStore(hs.get_db_conn(), hs)
+            ApplicationServiceStore(Database(hs), hs.get_db_conn(), hs)
 
         e = cm.exception
         self.assertIn(f1, str(e))
@@ -453,7 +456,7 @@ class ApplicationServiceStoreConfigTestCase(unittest.TestCase):
         hs.config.password_providers = []
 
         with self.assertRaises(ConfigError) as cm:
-            ApplicationServiceStore(hs.get_db_conn(), hs)
+            ApplicationServiceStore(Database(hs), hs.get_db_conn(), hs)
 
         e = cm.exception
         self.assertIn(f1, str(e))
diff --git a/tests/storage/test_background_update.py b/tests/storage/test_background_update.py
index 9fabe3fbc0..aec76f4ab1 100644
--- a/tests/storage/test_background_update.py
+++ b/tests/storage/test_background_update.py
@@ -15,7 +15,7 @@ class BackgroundUpdateTestCase(unittest.TestCase):
 
         self.update_handler = Mock()
 
-        yield self.store.register_background_update_handler(
+        yield self.store.db.updates.register_background_update_handler(
             "test_update", self.update_handler
         )
 
@@ -23,7 +23,7 @@ class BackgroundUpdateTestCase(unittest.TestCase):
         # (perhaps we should run them as part of the test HS setup, since we
         # run all of the other schema setup stuff there?)
         while True:
-            res = yield self.store.do_next_background_update(1000)
+            res = yield self.store.db.updates.do_next_background_update(1000)
             if res is None:
                 break
 
@@ -37,9 +37,9 @@ class BackgroundUpdateTestCase(unittest.TestCase):
         def update(progress, count):
             self.clock.advance_time_msec(count * duration_ms)
             progress = {"my_key": progress["my_key"] + 1}
-            yield self.store.runInteraction(
+            yield self.store.db.runInteraction(
                 "update_progress",
-                self.store._background_update_progress_txn,
+                self.store.db.updates._background_update_progress_txn,
                 "test_update",
                 progress,
             )
@@ -47,29 +47,37 @@ class BackgroundUpdateTestCase(unittest.TestCase):
 
         self.update_handler.side_effect = update
 
-        yield self.store.start_background_update("test_update", {"my_key": 1})
+        yield self.store.db.updates.start_background_update(
+            "test_update", {"my_key": 1}
+        )
 
         self.update_handler.reset_mock()
-        result = yield self.store.do_next_background_update(duration_ms * desired_count)
+        result = yield self.store.db.updates.do_next_background_update(
+            duration_ms * desired_count
+        )
         self.assertIsNotNone(result)
         self.update_handler.assert_called_once_with(
-            {"my_key": 1}, self.store.DEFAULT_BACKGROUND_BATCH_SIZE
+            {"my_key": 1}, self.store.db.updates.DEFAULT_BACKGROUND_BATCH_SIZE
         )
 
         # second step: complete the update
         @defer.inlineCallbacks
         def update(progress, count):
-            yield self.store._end_background_update("test_update")
+            yield self.store.db.updates._end_background_update("test_update")
             return count
 
         self.update_handler.side_effect = update
         self.update_handler.reset_mock()
-        result = yield self.store.do_next_background_update(duration_ms * desired_count)
+        result = yield self.store.db.updates.do_next_background_update(
+            duration_ms * desired_count
+        )
         self.assertIsNotNone(result)
         self.update_handler.assert_called_once_with({"my_key": 2}, desired_count)
 
         # third step: we don't expect to be called any more
         self.update_handler.reset_mock()
-        result = yield self.store.do_next_background_update(duration_ms * desired_count)
+        result = yield self.store.db.updates.do_next_background_update(
+            duration_ms * desired_count
+        )
         self.assertIsNone(result)
         self.assertFalse(self.update_handler.called)
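
The arithmetic this test leans on: the mock handler makes each item take `duration_ms` of fake clock time, and the updater is asked for batches worth `duration_ms * desired_count` milliseconds, so after one timed batch at the default size it should settle on `desired_count` items per batch. A self-contained sketch of that sizing rule (the real updater also smooths its timing average; this does not):

    DEFAULT_BACKGROUND_BATCH_SIZE = 100

    def next_batch_size(target_duration_ms, last_count, last_duration_ms):
        if last_count is None:
            return DEFAULT_BACKGROUND_BATCH_SIZE  # no timing data yet
        # scale so items-per-millisecond stays constant
        return int(target_duration_ms * last_count / last_duration_ms)

    # with duration_ms = 42 and desired_count = 1000, as in the test: the
    # first batch uses the default size, and once 100 items are seen to
    # take 100 * 42 ms, a 42 * 1000 ms target yields 1000 items per batch
    assert next_batch_size(42 * 1000, None, None) == 100
    assert next_batch_size(42 * 1000, 100, 100 * 42) == 1000
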
diff --git a/tests/storage/test_base.py b/tests/storage/test_base.py
index de5e4a5fce..537cfe9f64 100644
--- a/tests/storage/test_base.py
+++ b/tests/storage/test_base.py
@@ -21,6 +21,7 @@ from mock import Mock
 from twisted.internet import defer
 
 from synapse.storage._base import SQLBaseStore
+from synapse.storage.database import Database
 from synapse.storage.engines import create_engine
 
 from tests import unittest
@@ -59,13 +60,13 @@ class SQLBaseStoreTestCase(unittest.TestCase):
             "test", db_pool=self.db_pool, config=config, database_engine=fake_engine
         )
 
-        self.datastore = SQLBaseStore(None, hs)
+        self.datastore = SQLBaseStore(Database(hs), None, hs)
 
     @defer.inlineCallbacks
     def test_insert_1col(self):
         self.mock_txn.rowcount = 1
 
-        yield self.datastore.simple_insert(
+        yield self.datastore.db.simple_insert(
             table="tablename", values={"columname": "Value"}
         )
 
@@ -77,7 +78,7 @@ class SQLBaseStoreTestCase(unittest.TestCase):
     def test_insert_3cols(self):
         self.mock_txn.rowcount = 1
 
-        yield self.datastore.simple_insert(
+        yield self.datastore.db.simple_insert(
             table="tablename",
             # Use OrderedDict() so we can assert on the SQL generated
             values=OrderedDict([("colA", 1), ("colB", 2), ("colC", 3)]),
@@ -92,7 +93,7 @@ class SQLBaseStoreTestCase(unittest.TestCase):
         self.mock_txn.rowcount = 1
         self.mock_txn.__iter__ = Mock(return_value=iter([("Value",)]))
 
-        value = yield self.datastore.simple_select_one_onecol(
+        value = yield self.datastore.db.simple_select_one_onecol(
             table="tablename", keyvalues={"keycol": "TheKey"}, retcol="retcol"
         )
 
@@ -106,7 +107,7 @@ class SQLBaseStoreTestCase(unittest.TestCase):
         self.mock_txn.rowcount = 1
         self.mock_txn.fetchone.return_value = (1, 2, 3)
 
-        ret = yield self.datastore.simple_select_one(
+        ret = yield self.datastore.db.simple_select_one(
             table="tablename",
             keyvalues={"keycol": "TheKey"},
             retcols=["colA", "colB", "colC"],
@@ -122,7 +123,7 @@ class SQLBaseStoreTestCase(unittest.TestCase):
         self.mock_txn.rowcount = 0
         self.mock_txn.fetchone.return_value = None
 
-        ret = yield self.datastore.simple_select_one(
+        ret = yield self.datastore.db.simple_select_one(
             table="tablename",
             keyvalues={"keycol": "Not here"},
             retcols=["colA"],
@@ -137,7 +138,7 @@ class SQLBaseStoreTestCase(unittest.TestCase):
         self.mock_txn.__iter__ = Mock(return_value=iter([(1,), (2,), (3,)]))
         self.mock_txn.description = (("colA", None, None, None, None, None, None),)
 
-        ret = yield self.datastore.simple_select_list(
+        ret = yield self.datastore.db.simple_select_list(
             table="tablename", keyvalues={"keycol": "A set"}, retcols=["colA"]
         )
 
@@ -150,7 +151,7 @@ class SQLBaseStoreTestCase(unittest.TestCase):
     def test_update_one_1col(self):
         self.mock_txn.rowcount = 1
 
-        yield self.datastore.simple_update_one(
+        yield self.datastore.db.simple_update_one(
             table="tablename",
             keyvalues={"keycol": "TheKey"},
             updatevalues={"columnname": "New Value"},
@@ -165,7 +166,7 @@ class SQLBaseStoreTestCase(unittest.TestCase):
     def test_update_one_4cols(self):
         self.mock_txn.rowcount = 1
 
-        yield self.datastore.simple_update_one(
+        yield self.datastore.db.simple_update_one(
             table="tablename",
             keyvalues=OrderedDict([("colA", 1), ("colB", 2)]),
             updatevalues=OrderedDict([("colC", 3), ("colD", 4)]),
@@ -180,7 +181,7 @@ class SQLBaseStoreTestCase(unittest.TestCase):
     def test_delete_one(self):
         self.mock_txn.rowcount = 1
 
-        yield self.datastore.simple_delete_one(
+        yield self.datastore.db.simple_delete_one(
             table="tablename", keyvalues={"keycol": "Go away"}
         )
 
diff --git a/tests/storage/test_cleanup_extrems.py b/tests/storage/test_cleanup_extrems.py
index 69dcaa63d5..029ac26454 100644
--- a/tests/storage/test_cleanup_extrems.py
+++ b/tests/storage/test_cleanup_extrems.py
@@ -46,7 +46,9 @@ class CleanupExtremBackgroundUpdateStoreTestCase(HomeserverTestCase):
         """Re run the background update to clean up the extremities.
         """
         # Make sure we don't clash with in progress updates.
-        self.assertTrue(self.store._all_done, "Background updates are still ongoing")
+        self.assertTrue(
+            self.store.db.updates._all_done, "Background updates are still ongoing"
+        )
 
         schema_path = os.path.join(
             prepare_database.dir_path,
@@ -62,14 +64,20 @@ class CleanupExtremBackgroundUpdateStoreTestCase(HomeserverTestCase):
             prepare_database.executescript(txn, schema_path)
 
         self.get_success(
-            self.store.runInteraction("test_delete_forward_extremities", run_delta_file)
+            self.store.db.runInteraction(
+                "test_delete_forward_extremities", run_delta_file
+            )
         )
 
         # Ugh, have to reset this flag
-        self.store._all_done = False
+        self.store.db.updates._all_done = False
 
-        while not self.get_success(self.store.has_completed_background_updates()):
-            self.get_success(self.store.do_next_background_update(100), by=0.1)
+        while not self.get_success(
+            self.store.db.updates.has_completed_background_updates()
+        ):
+            self.get_success(
+                self.store.db.updates.do_next_background_update(100), by=0.1
+            )
 
     def test_soft_failed_extremities_handled_correctly(self):
         """Test that extremities are correctly calculated in the presence of
diff --git a/tests/storage/test_client_ips.py b/tests/storage/test_client_ips.py
index 25bdd2c163..bf674dd184 100644
--- a/tests/storage/test_client_ips.py
+++ b/tests/storage/test_client_ips.py
@@ -37,9 +37,13 @@ class ClientIpStoreTestCase(unittest.HomeserverTestCase):
         self.reactor.advance(12345678)
 
         user_id = "@user:id"
+        device_id = "MY_DEVICE"
+
+        # Insert a user IP
+        self.get_success(self.store.store_device(user_id, device_id, "display name",))
         self.get_success(
             self.store.insert_client_ip(
-                user_id, "access_token", "ip", "user_agent", "device_id"
+                user_id, "access_token", "ip", "user_agent", device_id
             )
         )
 
@@ -47,14 +51,14 @@ class ClientIpStoreTestCase(unittest.HomeserverTestCase):
         self.reactor.advance(10)
 
         result = self.get_success(
-            self.store.get_last_client_ip_by_device(user_id, "device_id")
+            self.store.get_last_client_ip_by_device(user_id, device_id)
         )
 
-        r = result[(user_id, "device_id")]
+        r = result[(user_id, device_id)]
         self.assertDictContainsSubset(
             {
                 "user_id": user_id,
-                "device_id": "device_id",
+                "device_id": device_id,
                 "ip": "ip",
                 "user_agent": "user_agent",
                 "last_seen": 12345678000,
@@ -81,7 +85,7 @@ class ClientIpStoreTestCase(unittest.HomeserverTestCase):
         self.pump(0)
 
         result = self.get_success(
-            self.store.simple_select_list(
+            self.store.db.simple_select_list(
                 table="user_ips",
                 keyvalues={"user_id": user_id},
                 retcols=["access_token", "ip", "user_agent", "device_id", "last_seen"],
@@ -112,7 +116,7 @@ class ClientIpStoreTestCase(unittest.HomeserverTestCase):
         self.pump(0)
 
         result = self.get_success(
-            self.store.simple_select_list(
+            self.store.db.simple_select_list(
                 table="user_ips",
                 keyvalues={"user_id": user_id},
                 retcols=["access_token", "ip", "user_agent", "device_id", "last_seen"],
@@ -202,25 +206,31 @@ class ClientIpStoreTestCase(unittest.HomeserverTestCase):
 
     def test_devices_last_seen_bg_update(self):
         # First make sure we have completed all updates.
-        while not self.get_success(self.store.has_completed_background_updates()):
-            self.get_success(self.store.do_next_background_update(100), by=0.1)
+        while not self.get_success(
+            self.store.db.updates.has_completed_background_updates()
+        ):
+            self.get_success(
+                self.store.db.updates.do_next_background_update(100), by=0.1
+            )
 
-        # Insert a user IP
         user_id = "@user:id"
+        device_id = "MY_DEVICE"
+
+        # Insert a user IP
+        self.get_success(self.store.store_device(user_id, device_id, "display name",))
         self.get_success(
             self.store.insert_client_ip(
-                user_id, "access_token", "ip", "user_agent", "device_id"
+                user_id, "access_token", "ip", "user_agent", device_id
             )
         )
-
         # Force persisting to disk
         self.reactor.advance(200)
 
         # But clear the associated entry in devices table
         self.get_success(
-            self.store.simple_update(
+            self.store.db.simple_update(
                 table="devices",
-                keyvalues={"user_id": user_id, "device_id": "device_id"},
+                keyvalues={"user_id": user_id, "device_id": device_id},
                 updatevalues={"last_seen": None, "ip": None, "user_agent": None},
                 desc="test_devices_last_seen_bg_update",
             )
@@ -228,14 +238,14 @@ class ClientIpStoreTestCase(unittest.HomeserverTestCase):
 
         # We should now get nulls when querying
         result = self.get_success(
-            self.store.get_last_client_ip_by_device(user_id, "device_id")
+            self.store.get_last_client_ip_by_device(user_id, device_id)
         )
 
-        r = result[(user_id, "device_id")]
+        r = result[(user_id, device_id)]
         self.assertDictContainsSubset(
             {
                 "user_id": user_id,
-                "device_id": "device_id",
+                "device_id": device_id,
                 "ip": None,
                 "user_agent": None,
                 "last_seen": None,
@@ -245,7 +255,7 @@ class ClientIpStoreTestCase(unittest.HomeserverTestCase):
 
         # Register the background update to run again.
         self.get_success(
-            self.store.simple_insert(
+            self.store.db.simple_insert(
                 table="background_updates",
                 values={
                     "update_name": "devices_last_seen",
@@ -256,22 +266,26 @@ class ClientIpStoreTestCase(unittest.HomeserverTestCase):
         )
 
         # ... and tell the DataStore that it hasn't finished all updates yet
-        self.store._all_done = False
+        self.store.db.updates._all_done = False
 
         # Now let's actually drive the updates to completion
-        while not self.get_success(self.store.has_completed_background_updates()):
-            self.get_success(self.store.do_next_background_update(100), by=0.1)
+        while not self.get_success(
+            self.store.db.updates.has_completed_background_updates()
+        ):
+            self.get_success(
+                self.store.db.updates.do_next_background_update(100), by=0.1
+            )
 
         # We should now get the correct result again
         result = self.get_success(
-            self.store.get_last_client_ip_by_device(user_id, "device_id")
+            self.store.get_last_client_ip_by_device(user_id, device_id)
         )
 
-        r = result[(user_id, "device_id")]
+        r = result[(user_id, device_id)]
         self.assertDictContainsSubset(
             {
                 "user_id": user_id,
-                "device_id": "device_id",
+                "device_id": device_id,
                 "ip": "ip",
                 "user_agent": "user_agent",
                 "last_seen": 0,
@@ -281,14 +295,21 @@ class ClientIpStoreTestCase(unittest.HomeserverTestCase):
 
     def test_old_user_ips_pruned(self):
         # First make sure we have completed all updates.
-        while not self.get_success(self.store.has_completed_background_updates()):
-            self.get_success(self.store.do_next_background_update(100), by=0.1)
+        while not self.get_success(
+            self.store.db.updates.has_completed_background_updates()
+        ):
+            self.get_success(
+                self.store.db.updates.do_next_background_update(100), by=0.1
+            )
 
-        # Insert a user IP
         user_id = "@user:id"
+        device_id = "MY_DEVICE"
+
+        # Insert a user IP
+        self.get_success(self.store.store_device(user_id, device_id, "display name",))
         self.get_success(
             self.store.insert_client_ip(
-                user_id, "access_token", "ip", "user_agent", "device_id"
+                user_id, "access_token", "ip", "user_agent", device_id
             )
         )
 
@@ -297,7 +318,7 @@ class ClientIpStoreTestCase(unittest.HomeserverTestCase):
 
         # We should see that in the DB
         result = self.get_success(
-            self.store.simple_select_list(
+            self.store.db.simple_select_list(
                 table="user_ips",
                 keyvalues={"user_id": user_id},
                 retcols=["access_token", "ip", "user_agent", "device_id", "last_seen"],
@@ -312,7 +333,7 @@ class ClientIpStoreTestCase(unittest.HomeserverTestCase):
                     "access_token": "access_token",
                     "ip": "ip",
                     "user_agent": "user_agent",
-                    "device_id": "device_id",
+                    "device_id": device_id,
                     "last_seen": 0,
                 }
             ],
@@ -323,7 +344,7 @@ class ClientIpStoreTestCase(unittest.HomeserverTestCase):
 
         # We should get no results.
         result = self.get_success(
-            self.store.simple_select_list(
+            self.store.db.simple_select_list(
                 table="user_ips",
                 keyvalues={"user_id": user_id},
                 retcols=["access_token", "ip", "user_agent", "device_id", "last_seen"],
@@ -335,14 +356,14 @@ class ClientIpStoreTestCase(unittest.HomeserverTestCase):
 
         # But we should still get the correct values for the device
         result = self.get_success(
-            self.store.get_last_client_ip_by_device(user_id, "device_id")
+            self.store.get_last_client_ip_by_device(user_id, device_id)
         )
 
-        r = result[(user_id, "device_id")]
+        r = result[(user_id, device_id)]
         self.assertDictContainsSubset(
             {
                 "user_id": user_id,
-                "device_id": "device_id",
+                "device_id": device_id,
                 "ip": "ip",
                 "user_agent": "user_agent",
                 "last_seen": 0,
diff --git a/tests/storage/test_event_federation.py b/tests/storage/test_event_federation.py
index 2fe50377f8..eadfb90a22 100644
--- a/tests/storage/test_event_federation.py
+++ b/tests/storage/test_event_federation.py
@@ -61,7 +61,7 @@ class EventFederationWorkerStoreTestCase(tests.unittest.TestCase):
             )
 
         for i in range(0, 11):
-            yield self.store.runInteraction("insert", insert_event, i)
+            yield self.store.db.runInteraction("insert", insert_event, i)
 
         # this should get the last five and five others
         r = yield self.store.get_prev_events_for_room(room_id)
@@ -93,9 +93,9 @@ class EventFederationWorkerStoreTestCase(tests.unittest.TestCase):
             )
 
         for i in range(0, 20):
-            yield self.store.runInteraction("insert", insert_event, i, room1)
-            yield self.store.runInteraction("insert", insert_event, i, room2)
-            yield self.store.runInteraction("insert", insert_event, i, room3)
+            yield self.store.db.runInteraction("insert", insert_event, i, room1)
+            yield self.store.db.runInteraction("insert", insert_event, i, room2)
+            yield self.store.db.runInteraction("insert", insert_event, i, room3)
 
         # Test simple case
         r = yield self.store.get_rooms_with_many_extremities(5, 5, [])
diff --git a/tests/storage/test_event_push_actions.py b/tests/storage/test_event_push_actions.py
index 2337a1ae46..d4bcf1821e 100644
--- a/tests/storage/test_event_push_actions.py
+++ b/tests/storage/test_event_push_actions.py
@@ -55,7 +55,7 @@ class EventPushActionsStoreTestCase(tests.unittest.TestCase):
 
         @defer.inlineCallbacks
         def _assert_counts(noitf_count, highlight_count):
-            counts = yield self.store.runInteraction(
+            counts = yield self.store.db.runInteraction(
                 "", self.store._get_unread_counts_by_pos_txn, room_id, user_id, 0
             )
             self.assertEquals(
@@ -74,7 +74,7 @@ class EventPushActionsStoreTestCase(tests.unittest.TestCase):
             yield self.store.add_push_actions_to_staging(
                 event.event_id, {user_id: action}
             )
-            yield self.store.runInteraction(
+            yield self.store.db.runInteraction(
                 "",
                 self.store._set_push_actions_for_event_and_users_txn,
                 [(event, None)],
@@ -82,12 +82,12 @@ class EventPushActionsStoreTestCase(tests.unittest.TestCase):
             )
 
         def _rotate(stream):
-            return self.store.runInteraction(
+            return self.store.db.runInteraction(
                 "", self.store._rotate_notifs_before_txn, stream
             )
 
         def _mark_read(stream, depth):
-            return self.store.runInteraction(
+            return self.store.db.runInteraction(
                 "",
                 self.store._remove_old_push_actions_before_txn,
                 room_id,
@@ -116,7 +116,7 @@ class EventPushActionsStoreTestCase(tests.unittest.TestCase):
         yield _inject_actions(6, PlAIN_NOTIF)
         yield _rotate(7)
 
-        yield self.store.simple_delete(
+        yield self.store.db.simple_delete(
             table="event_push_actions", keyvalues={"1": 1}, desc=""
         )
 
@@ -135,7 +135,7 @@ class EventPushActionsStoreTestCase(tests.unittest.TestCase):
     @defer.inlineCallbacks
     def test_find_first_stream_ordering_after_ts(self):
         def add_event(so, ts):
-            return self.store.simple_insert(
+            return self.store.db.simple_insert(
                 "events",
                 {
                     "stream_ordering": so,
diff --git a/tests/storage/test_monthly_active_users.py b/tests/storage/test_monthly_active_users.py
index 90a63dc477..3c78faab45 100644
--- a/tests/storage/test_monthly_active_users.py
+++ b/tests/storage/test_monthly_active_users.py
@@ -65,7 +65,7 @@ class MonthlyActiveUsersTestCase(unittest.HomeserverTestCase):
         self.store.user_add_threepid(user1, "email", user1_email, now, now)
         self.store.user_add_threepid(user2, "email", user2_email, now, now)
 
-        self.store.runInteraction(
+        self.store.db.runInteraction(
             "initialise", self.store._initialise_reserved_users, threepids
         )
         self.pump()
@@ -183,7 +183,7 @@ class MonthlyActiveUsersTestCase(unittest.HomeserverTestCase):
                 )
 
         self.hs.config.mau_limits_reserved_threepids = threepids
-        self.store.runInteraction(
+        self.store.db.runInteraction(
             "initialise", self.store._initialise_reserved_users, threepids
         )
         count = self.store.get_monthly_active_count()
@@ -244,7 +244,7 @@ class MonthlyActiveUsersTestCase(unittest.HomeserverTestCase):
             {"medium": "email", "address": user2_email},
         ]
         self.hs.config.mau_limits_reserved_threepids = threepids
-        self.store.runInteraction(
+        self.store.db.runInteraction(
             "initialise", self.store._initialise_reserved_users, threepids
         )
 
diff --git a/tests/storage/test_profile.py b/tests/storage/test_profile.py
index 24c7fe16c3..9b6f7211ae 100644
--- a/tests/storage/test_profile.py
+++ b/tests/storage/test_profile.py
@@ -16,7 +16,6 @@
 
 from twisted.internet import defer
 
-from synapse.storage.data_stores.main.profile import ProfileStore
 from synapse.types import UserID
 
 from tests import unittest
@@ -28,7 +27,7 @@ class ProfileStoreTestCase(unittest.TestCase):
     def setUp(self):
         hs = yield setup_test_homeserver(self.addCleanup)
 
-        self.store = ProfileStore(hs.get_db_conn(), hs)
+        self.store = hs.get_datastore()
 
         self.u_frank = UserID.from_string("@frank:test")
 
diff --git a/tests/storage/test_redaction.py b/tests/storage/test_redaction.py
index 4930b6777e..dc45173355 100644
--- a/tests/storage/test_redaction.py
+++ b/tests/storage/test_redaction.py
@@ -338,7 +338,7 @@ class RedactionTestCase(unittest.HomeserverTestCase):
         )
 
         event_json = self.get_success(
-            self.store.simple_select_one_onecol(
+            self.store.db.simple_select_one_onecol(
                 table="event_json",
                 keyvalues={"event_id": msg_event.event_id},
                 retcol="json",
@@ -356,7 +356,7 @@ class RedactionTestCase(unittest.HomeserverTestCase):
         self.reactor.advance(60 * 60 * 2)
 
         event_json = self.get_success(
-            self.store.simple_select_one_onecol(
+            self.store.db.simple_select_one_onecol(
                 table="event_json",
                 keyvalues={"event_id": msg_event.event_id},
                 retcol="json",
diff --git a/tests/storage/test_roommember.py b/tests/storage/test_roommember.py
index d389cf578f..7840f63fe3 100644
--- a/tests/storage/test_roommember.py
+++ b/tests/storage/test_roommember.py
@@ -122,8 +122,12 @@ class CurrentStateMembershipUpdateTestCase(unittest.HomeserverTestCase):
 
     def test_can_rerun_update(self):
         # First make sure we have completed all updates.
-        while not self.get_success(self.store.has_completed_background_updates()):
-            self.get_success(self.store.do_next_background_update(100), by=0.1)
+        while not self.get_success(
+            self.store.db.updates.has_completed_background_updates()
+        ):
+            self.get_success(
+                self.store.db.updates.do_next_background_update(100), by=0.1
+            )
 
         # Now let's create a room, which will insert a membership
         user = UserID("alice", "test")
@@ -132,7 +136,7 @@ class CurrentStateMembershipUpdateTestCase(unittest.HomeserverTestCase):
 
         # Register the background update to run again.
         self.get_success(
-            self.store.simple_insert(
+            self.store.db.simple_insert(
                 table="background_updates",
                 values={
                     "update_name": "current_state_events_membership",
@@ -143,8 +147,12 @@ class CurrentStateMembershipUpdateTestCase(unittest.HomeserverTestCase):
         )
 
         # ... and tell the DataStore that it hasn't finished all updates yet
-        self.store._all_done = False
+        self.store.db.updates._all_done = False
 
         # Now let's actually drive the updates to completion
-        while not self.get_success(self.store.has_completed_background_updates()):
-            self.get_success(self.store.do_next_background_update(100), by=0.1)
+        while not self.get_success(
+            self.store.db.updates.has_completed_background_updates()
+        ):
+            self.get_success(
+                self.store.db.updates.do_next_background_update(100), by=0.1
+            )
diff --git a/tests/storage/test_user_directory.py b/tests/storage/test_user_directory.py
index 7eea57c0e2..6a545d2eb0 100644
--- a/tests/storage/test_user_directory.py
+++ b/tests/storage/test_user_directory.py
@@ -15,8 +15,6 @@
 
 from twisted.internet import defer
 
-from synapse.storage.data_stores.main.user_directory import UserDirectoryStore
-
 from tests import unittest
 from tests.utils import setup_test_homeserver
 
@@ -29,7 +27,7 @@ class UserDirectoryStoreTestCase(unittest.TestCase):
     @defer.inlineCallbacks
     def setUp(self):
         self.hs = yield setup_test_homeserver(self.addCleanup)
-        self.store = UserDirectoryStore(self.hs.get_db_conn(), self.hs)
+        self.store = self.hs.get_datastore()
 
         # alice and bob are both in !room_id. bobby is not but shares
         # a homeserver with alice.
diff --git a/tests/test_federation.py b/tests/test_federation.py
index 7d82b58466..ad165d7295 100644
--- a/tests/test_federation.py
+++ b/tests/test_federation.py
@@ -33,6 +33,8 @@ class MessageAcceptTests(unittest.TestCase):
         self.reactor.advance(0.1)
         self.room_id = self.successResultOf(room)["room_id"]
 
+        self.store = self.homeserver.get_datastore()
+
         # Figure out what the most recent event is
         most_recent = self.successResultOf(
             maybeDeferred(
@@ -77,10 +79,7 @@ class MessageAcceptTests(unittest.TestCase):
         # Make sure we actually joined the room
         self.assertEqual(
             self.successResultOf(
-                maybeDeferred(
-                    self.homeserver.get_datastore().get_latest_event_ids_in_room,
-                    self.room_id,
-                )
+                maybeDeferred(self.store.get_latest_event_ids_in_room, self.room_id)
             )[0],
             "$join:test.serv",
         )
@@ -100,10 +99,7 @@ class MessageAcceptTests(unittest.TestCase):
 
         # Figure out what the most recent event is
         most_recent = self.successResultOf(
-            maybeDeferred(
-                self.homeserver.get_datastore().get_latest_event_ids_in_room,
-                self.room_id,
-            )
+            maybeDeferred(self.store.get_latest_event_ids_in_room, self.room_id)
         )[0]
 
         # Now lie about an event
@@ -141,7 +137,5 @@ class MessageAcceptTests(unittest.TestCase):
         )
 
         # Make sure the invalid event isn't there
-        extrem = maybeDeferred(
-            self.homeserver.get_datastore().get_latest_event_ids_in_room, self.room_id
-        )
+        extrem = maybeDeferred(self.store.get_latest_event_ids_in_room, self.room_id)
         self.assertEqual(self.successResultOf(extrem)[0], "$join:test.serv")
diff --git a/tests/unittest.py b/tests/unittest.py
index 295573bc46..b30b7d1718 100644
--- a/tests/unittest.py
+++ b/tests/unittest.py
@@ -18,6 +18,7 @@
 import gc
 import hashlib
 import hmac
+import inspect
 import logging
 import time
 
@@ -25,7 +26,7 @@ from mock import Mock
 
 from canonicaljson import json
 
-from twisted.internet.defer import Deferred, succeed
+from twisted.internet.defer import Deferred, ensureDeferred, succeed
 from twisted.python.threadpool import ThreadPool
 from twisted.trial import unittest
 
@@ -401,10 +402,12 @@ class HomeserverTestCase(TestCase):
         hs = setup_test_homeserver(self.addCleanup, *args, **kwargs)
         stor = hs.get_datastore()
 
-        # Run the database background updates.
-        if hasattr(stor, "do_next_background_update"):
-            while not self.get_success(stor.has_completed_background_updates()):
-                self.get_success(stor.do_next_background_update(1))
+        # Run the database background updates when running against "master".
+        if hs.__class__.__name__ == "TestHomeServer":
+            while not self.get_success(
+                stor.db.updates.has_completed_background_updates()
+            ):
+                self.get_success(stor.db.updates.do_next_background_update(1))
 
         return hs
 
@@ -415,6 +418,8 @@ class HomeserverTestCase(TestCase):
         self.reactor.pump([by] * 100)
 
     def get_success(self, d, by=0.0):
+        if inspect.isawaitable(d):
+            d = ensureDeferred(d)
         if not isinstance(d, Deferred):
             return d
         self.pump(by=by)
@@ -424,6 +429,8 @@ class HomeserverTestCase(TestCase):
         """
         Run a Deferred and get a Failure from it. The failure must be of the type `exc`.
         """
+        if inspect.isawaitable(d):
+            d = ensureDeferred(d)
         if not isinstance(d, Deferred):
             return d
         self.pump()
@@ -544,7 +551,7 @@ class HomeserverTestCase(TestCase):
         Add the given event as an extremity to the room.
         """
         self.get_success(
-            self.hs.get_datastore().simple_insert(
+            self.hs.get_datastore().db.simple_insert(
                 table="event_forward_extremities",
                 values={"room_id": room_id, "event_id": event_id},
                 desc="test_add_extremity",
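The `get_success`/`get_failure` changes above let tests pass either a Deferred or a bare coroutine returned by an async function. A minimal sketch of that coercion in isolation, assuming nothing beyond the Twisted API used in the diff (`normalize` is an illustrative name, not part of the test harness):

    import inspect

    from twisted.internet.defer import ensureDeferred

    def normalize(d):
        # Coroutines (and other awaitables) are not plain Deferreds, so
        # wrap them; ensureDeferred passes real Deferreds through
        # unchanged, letting Deferred-based helpers keep working.
        if inspect.isawaitable(d):
            d = ensureDeferred(d)
        return d

Since Deferreds themselves implement `__await__`, the `isawaitable` check matches them too, and `ensureDeferred` simply returns them as-is.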
diff --git a/tests/util/test_logcontext.py b/tests/util/test_logcontext.py
index 8b8455c8b7..281b32c4b8 100644
--- a/tests/util/test_logcontext.py
+++ b/tests/util/test_logcontext.py
@@ -179,6 +179,30 @@ class LoggingContextTestCase(unittest.TestCase):
             nested_context = nested_logging_context(suffix="bar")
             self.assertEqual(nested_context.request, "foo-bar")
 
+    @defer.inlineCallbacks
+    def test_make_deferred_yieldable_with_await(self):
+    # an async function which returns an incomplete coroutine, but doesn't
+        # follow the synapse rules.
+
+        async def blocking_function():
+            d = defer.Deferred()
+            reactor.callLater(0, d.callback, None)
+            await d
+
+        sentinel_context = LoggingContext.current_context()
+
+        with LoggingContext() as context_one:
+            context_one.request = "one"
+
+            d1 = make_deferred_yieldable(blocking_function())
+            # make sure that the context was reset by make_deferred_yieldable
+            self.assertIs(LoggingContext.current_context(), sentinel_context)
+
+            yield d1
+
+            # now it should be restored
+            self._check_test_key("one")
+
 
 # a function which returns a deferred which has been "called", but
 # which had a function which returned another incomplete deferred on
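The new test above pins down the logcontext contract for awaitables: `make_deferred_yieldable` resets the current context to the sentinel while the work is pending, then restores the caller's context once the result is yielded. A minimal usage sketch under the same assumptions as the test; `fetch_something` is a hypothetical stand-in for any awaitable that doesn't follow the Synapse logcontext rules:

    from twisted.internet import defer

    from synapse.logging.context import (
        LoggingContext,
        make_deferred_yieldable,
    )

    @defer.inlineCallbacks
    def handler():
        with LoggingContext() as ctx:
            ctx.request = "req1"
            d = make_deferred_yieldable(fetch_something())
            # While d is pending, the current context is the sentinel,
            # so reactor callbacks don't run tagged with "req1"...
            result = yield d
            # ...and after the yield, "req1" is the current context again.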
diff --git a/tests/util/test_snapshot_cache.py b/tests/util/test_snapshot_cache.py
deleted file mode 100644
index 1a44f72425..0000000000
--- a/tests/util/test_snapshot_cache.py
+++ /dev/null
@@ -1,63 +0,0 @@
-# -*- coding: utf-8 -*-
-# Copyright 2015, 2016 OpenMarket Ltd
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-#     http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-
-
-from twisted.internet.defer import Deferred
-
-from synapse.util.caches.snapshot_cache import SnapshotCache
-
-from .. import unittest
-
-
-class SnapshotCacheTestCase(unittest.TestCase):
-    def setUp(self):
-        self.cache = SnapshotCache()
-        self.cache.DURATION_MS = 1
-
-    def test_get_set(self):
-        # Check that getting a missing key returns None
-        self.assertEquals(self.cache.get(0, "key"), None)
-
-        # Check that setting a key with a deferred returns
-        # a deferred that resolves when the initial deferred does
-        d = Deferred()
-        set_result = self.cache.set(0, "key", d)
-        self.assertIsNotNone(set_result)
-        self.assertFalse(set_result.called)
-
-        # Check that getting the key before the deferred has resolved
-        # returns a deferred that resolves when the initial deferred does.
-        get_result_at_10 = self.cache.get(10, "key")
-        self.assertIsNotNone(get_result_at_10)
-        self.assertFalse(get_result_at_10.called)
-
-        # Check that the returned deferreds resolve when the initial deferred
-        # does.
-        d.callback("v")
-        self.assertTrue(set_result.called)
-        self.assertTrue(get_result_at_10.called)
-
-        # Check that getting the key after the deferred has resolved
-        # before the cache expires returns a resolved deferred.
-        get_result_at_11 = self.cache.get(11, "key")
-        self.assertIsNotNone(get_result_at_11)
-        if isinstance(get_result_at_11, Deferred):
-            # The cache may return the actual result rather than a deferred
-            self.assertTrue(get_result_at_11.called)
-
-        # Check that getting the key after the deferred has resolved
-        # after the cache expires returns None
-        get_result_at_12 = self.cache.get(12, "key")
-        self.assertIsNone(get_result_at_12)