Diffstat (limited to 'synapse/storage')
-rw-r--r--synapse/storage/__init__.py552
-rw-r--r--synapse/storage/_base.py1639
-rw-r--r--synapse/storage/background_updates.py187
-rw-r--r--synapse/storage/data_stores/__init__.py97
-rw-r--r--synapse/storage/data_stores/main/__init__.py592
-rw-r--r--synapse/storage/data_stores/main/account_data.py (renamed from synapse/storage/account_data.py)121
-rw-r--r--synapse/storage/data_stores/main/appservice.py (renamed from synapse/storage/appservice.py)52
-rw-r--r--synapse/storage/data_stores/main/cache.py280
-rw-r--r--synapse/storage/data_stores/main/censor_events.py208
-rw-r--r--synapse/storage/data_stores/main/client_ips.py (renamed from synapse/storage/client_ips.py)302
-rw-r--r--synapse/storage/data_stores/main/deviceinbox.py (renamed from synapse/storage/deviceinbox.py)165
-rw-r--r--synapse/storage/data_stores/main/devices.py (renamed from synapse/storage/devices.py)775
-rw-r--r--synapse/storage/data_stores/main/directory.py (renamed from synapse/storage/directory.py)49
-rw-r--r--synapse/storage/data_stores/main/e2e_room_keys.py (renamed from synapse/storage/e2e_room_keys.py)256
-rw-r--r--synapse/storage/data_stores/main/end_to_end_keys.py721
-rw-r--r--synapse/storage/data_stores/main/event_federation.py (renamed from synapse/storage/event_federation.py)428
-rw-r--r--synapse/storage/data_stores/main/event_push_actions.py (renamed from synapse/storage/event_push_actions.py)163
-rw-r--r--synapse/storage/data_stores/main/events.py1620
-rw-r--r--synapse/storage/data_stores/main/events_bg_updates.py (renamed from synapse/storage/events_bg_updates.py)246
-rw-r--r--synapse/storage/data_stores/main/events_worker.py (renamed from synapse/storage/events_worker.py)711
-rw-r--r--synapse/storage/data_stores/main/filtering.py (renamed from synapse/storage/filtering.py)13
-rw-r--r--synapse/storage/data_stores/main/group_server.py (renamed from synapse/storage/group_server.py)920
-rw-r--r--synapse/storage/data_stores/main/keys.py208
-rw-r--r--synapse/storage/data_stores/main/media_repository.py (renamed from synapse/storage/media_repository.py)76
-rw-r--r--synapse/storage/data_stores/main/metrics.py128
-rw-r--r--synapse/storage/data_stores/main/monthly_active_users.py (renamed from synapse/storage/monthly_active_users.py)325
-rw-r--r--synapse/storage/data_stores/main/openid.py (renamed from synapse/storage/openid.py)8
-rw-r--r--synapse/storage/data_stores/main/presence.py153
-rw-r--r--synapse/storage/data_stores/main/profile.py (renamed from synapse/storage/profile.py)35
-rw-r--r--synapse/storage/data_stores/main/purge_events.py399
-rw-r--r--synapse/storage/data_stores/main/push_rule.py729
-rw-r--r--synapse/storage/data_stores/main/pusher.py (renamed from synapse/storage/pusher.py)214
-rw-r--r--synapse/storage/data_stores/main/receipts.py (renamed from synapse/storage/receipts.py)111
-rw-r--r--synapse/storage/data_stores/main/registration.py (renamed from synapse/storage/registration.py)530
-rw-r--r--synapse/storage/data_stores/main/rejections.py (renamed from synapse/storage/rejections.py)15
-rw-r--r--synapse/storage/data_stores/main/relations.py327
-rw-r--r--synapse/storage/data_stores/main/room.py1433
-rw-r--r--synapse/storage/data_stores/main/roommember.py1138
-rw-r--r--synapse/storage/data_stores/main/schema/delta/12/v12.sql (renamed from synapse/storage/schema/delta/12/v12.sql)0
-rw-r--r--synapse/storage/data_stores/main/schema/delta/13/v13.sql (renamed from synapse/storage/schema/delta/13/v13.sql)0
-rw-r--r--synapse/storage/data_stores/main/schema/delta/14/v14.sql (renamed from synapse/storage/schema/delta/14/v14.sql)0
-rw-r--r--synapse/storage/data_stores/main/schema/delta/15/appservice_txns.sql (renamed from synapse/storage/schema/delta/15/appservice_txns.sql)0
-rw-r--r--synapse/storage/data_stores/main/schema/delta/15/presence_indices.sql (renamed from synapse/storage/schema/delta/15/presence_indices.sql)0
-rw-r--r--synapse/storage/data_stores/main/schema/delta/15/v15.sql (renamed from synapse/storage/schema/delta/15/v15.sql)0
-rw-r--r--synapse/storage/data_stores/main/schema/delta/16/events_order_index.sql (renamed from synapse/storage/schema/delta/16/events_order_index.sql)0
-rw-r--r--synapse/storage/data_stores/main/schema/delta/16/remote_media_cache_index.sql (renamed from synapse/storage/schema/delta/16/remote_media_cache_index.sql)0
-rw-r--r--synapse/storage/data_stores/main/schema/delta/16/remove_duplicates.sql (renamed from synapse/storage/schema/delta/16/remove_duplicates.sql)0
-rw-r--r--synapse/storage/data_stores/main/schema/delta/16/room_alias_index.sql (renamed from synapse/storage/schema/delta/16/room_alias_index.sql)0
-rw-r--r--synapse/storage/data_stores/main/schema/delta/16/unique_constraints.sql (renamed from synapse/storage/schema/delta/16/unique_constraints.sql)0
-rw-r--r--synapse/storage/data_stores/main/schema/delta/16/users.sql (renamed from synapse/storage/schema/delta/16/users.sql)0
-rw-r--r--synapse/storage/data_stores/main/schema/delta/17/drop_indexes.sql (renamed from synapse/storage/schema/delta/17/drop_indexes.sql)0
-rw-r--r--synapse/storage/data_stores/main/schema/delta/17/server_keys.sql (renamed from synapse/storage/schema/delta/17/server_keys.sql)0
-rw-r--r--synapse/storage/data_stores/main/schema/delta/17/user_threepids.sql (renamed from synapse/storage/schema/delta/17/user_threepids.sql)0
-rw-r--r--synapse/storage/data_stores/main/schema/delta/18/server_keys_bigger_ints.sql (renamed from synapse/storage/schema/delta/18/server_keys_bigger_ints.sql)0
-rw-r--r--synapse/storage/data_stores/main/schema/delta/19/event_index.sql (renamed from synapse/storage/schema/delta/19/event_index.sql)0
-rw-r--r--synapse/storage/data_stores/main/schema/delta/20/dummy.sql (renamed from synapse/storage/schema/delta/20/dummy.sql)0
-rw-r--r--synapse/storage/data_stores/main/schema/delta/20/pushers.py (renamed from synapse/storage/schema/delta/20/pushers.py)0
-rw-r--r--synapse/storage/data_stores/main/schema/delta/21/end_to_end_keys.sql (renamed from synapse/storage/schema/delta/21/end_to_end_keys.sql)0
-rw-r--r--synapse/storage/data_stores/main/schema/delta/21/receipts.sql (renamed from synapse/storage/schema/delta/21/receipts.sql)0
-rw-r--r--synapse/storage/data_stores/main/schema/delta/22/receipts_index.sql (renamed from synapse/storage/schema/delta/22/receipts_index.sql)0
-rw-r--r--synapse/storage/data_stores/main/schema/delta/22/user_threepids_unique.sql (renamed from synapse/storage/schema/delta/22/user_threepids_unique.sql)0
-rw-r--r--synapse/storage/data_stores/main/schema/delta/24/stats_reporting.sql (renamed from synapse/storage/schema/delta/24/stats_reporting.sql)0
-rw-r--r--synapse/storage/data_stores/main/schema/delta/25/fts.py (renamed from synapse/storage/schema/delta/25/fts.py)0
-rw-r--r--synapse/storage/data_stores/main/schema/delta/25/guest_access.sql (renamed from synapse/storage/schema/delta/25/guest_access.sql)0
-rw-r--r--synapse/storage/data_stores/main/schema/delta/25/history_visibility.sql (renamed from synapse/storage/schema/delta/25/history_visibility.sql)0
-rw-r--r--synapse/storage/data_stores/main/schema/delta/25/tags.sql (renamed from synapse/storage/schema/delta/25/tags.sql)0
-rw-r--r--synapse/storage/data_stores/main/schema/delta/26/account_data.sql (renamed from synapse/storage/schema/delta/26/account_data.sql)0
-rw-r--r--synapse/storage/data_stores/main/schema/delta/27/account_data.sql (renamed from synapse/storage/schema/delta/27/account_data.sql)0
-rw-r--r--synapse/storage/data_stores/main/schema/delta/27/forgotten_memberships.sql (renamed from synapse/storage/schema/delta/27/forgotten_memberships.sql)0
-rw-r--r--synapse/storage/data_stores/main/schema/delta/27/ts.py (renamed from synapse/storage/schema/delta/27/ts.py)0
-rw-r--r--synapse/storage/data_stores/main/schema/delta/28/event_push_actions.sql (renamed from synapse/storage/schema/delta/28/event_push_actions.sql)0
-rw-r--r--synapse/storage/data_stores/main/schema/delta/28/events_room_stream.sql (renamed from synapse/storage/schema/delta/28/events_room_stream.sql)0
-rw-r--r--synapse/storage/data_stores/main/schema/delta/28/public_roms_index.sql (renamed from synapse/storage/schema/delta/28/public_roms_index.sql)0
-rw-r--r--synapse/storage/data_stores/main/schema/delta/28/receipts_user_id_index.sql (renamed from synapse/storage/schema/delta/28/receipts_user_id_index.sql)0
-rw-r--r--synapse/storage/data_stores/main/schema/delta/28/upgrade_times.sql (renamed from synapse/storage/schema/delta/28/upgrade_times.sql)0
-rw-r--r--synapse/storage/data_stores/main/schema/delta/28/users_is_guest.sql (renamed from synapse/storage/schema/delta/28/users_is_guest.sql)0
-rw-r--r--synapse/storage/data_stores/main/schema/delta/29/push_actions.sql (renamed from synapse/storage/schema/delta/29/push_actions.sql)0
-rw-r--r--synapse/storage/data_stores/main/schema/delta/30/alias_creator.sql (renamed from synapse/storage/schema/delta/30/alias_creator.sql)0
-rw-r--r--synapse/storage/data_stores/main/schema/delta/30/as_users.py (renamed from synapse/storage/schema/delta/30/as_users.py)0
-rw-r--r--synapse/storage/data_stores/main/schema/delta/30/deleted_pushers.sql (renamed from synapse/storage/schema/delta/30/deleted_pushers.sql)0
-rw-r--r--synapse/storage/data_stores/main/schema/delta/30/presence_stream.sql (renamed from synapse/storage/schema/delta/30/presence_stream.sql)0
-rw-r--r--synapse/storage/data_stores/main/schema/delta/30/public_rooms.sql (renamed from synapse/storage/schema/delta/30/public_rooms.sql)0
-rw-r--r--synapse/storage/data_stores/main/schema/delta/30/push_rule_stream.sql (renamed from synapse/storage/schema/delta/30/push_rule_stream.sql)0
-rw-r--r--synapse/storage/data_stores/main/schema/delta/30/threepid_guest_access_tokens.sql (renamed from synapse/storage/schema/delta/30/threepid_guest_access_tokens.sql)0
-rw-r--r--synapse/storage/data_stores/main/schema/delta/31/invites.sql (renamed from synapse/storage/schema/delta/31/invites.sql)0
-rw-r--r--synapse/storage/data_stores/main/schema/delta/31/local_media_repository_url_cache.sql (renamed from synapse/storage/schema/delta/31/local_media_repository_url_cache.sql)0
-rw-r--r--synapse/storage/data_stores/main/schema/delta/31/pushers.py (renamed from synapse/storage/schema/delta/31/pushers.py)0
-rw-r--r--synapse/storage/data_stores/main/schema/delta/31/pushers_index.sql (renamed from synapse/storage/schema/delta/31/pushers_index.sql)0
-rw-r--r--synapse/storage/data_stores/main/schema/delta/31/search_update.py (renamed from synapse/storage/schema/delta/31/search_update.py)0
-rw-r--r--synapse/storage/data_stores/main/schema/delta/32/events.sql (renamed from synapse/storage/schema/delta/32/events.sql)0
-rw-r--r--synapse/storage/data_stores/main/schema/delta/32/openid.sql (renamed from synapse/storage/schema/delta/32/openid.sql)0
-rw-r--r--synapse/storage/data_stores/main/schema/delta/32/pusher_throttle.sql (renamed from synapse/storage/schema/delta/32/pusher_throttle.sql)0
-rw-r--r--synapse/storage/data_stores/main/schema/delta/32/remove_indices.sql (renamed from synapse/storage/schema/delta/32/remove_indices.sql)1
-rw-r--r--synapse/storage/data_stores/main/schema/delta/32/reports.sql (renamed from synapse/storage/schema/delta/32/reports.sql)0
-rw-r--r--synapse/storage/data_stores/main/schema/delta/33/access_tokens_device_index.sql (renamed from synapse/storage/schema/delta/33/access_tokens_device_index.sql)0
-rw-r--r--synapse/storage/data_stores/main/schema/delta/33/devices.sql (renamed from synapse/storage/schema/delta/33/devices.sql)0
-rw-r--r--synapse/storage/data_stores/main/schema/delta/33/devices_for_e2e_keys.sql (renamed from synapse/storage/schema/delta/33/devices_for_e2e_keys.sql)0
-rw-r--r--synapse/storage/data_stores/main/schema/delta/33/devices_for_e2e_keys_clear_unknown_device.sql (renamed from synapse/storage/schema/delta/33/devices_for_e2e_keys_clear_unknown_device.sql)0
-rw-r--r--synapse/storage/data_stores/main/schema/delta/33/event_fields.py (renamed from synapse/storage/schema/delta/33/event_fields.py)0
-rw-r--r--synapse/storage/data_stores/main/schema/delta/33/remote_media_ts.py (renamed from synapse/storage/schema/delta/33/remote_media_ts.py)0
-rw-r--r--synapse/storage/data_stores/main/schema/delta/33/user_ips_index.sql (renamed from synapse/storage/schema/delta/33/user_ips_index.sql)0
-rw-r--r--synapse/storage/data_stores/main/schema/delta/34/appservice_stream.sql (renamed from synapse/storage/schema/delta/34/appservice_stream.sql)0
-rw-r--r--synapse/storage/data_stores/main/schema/delta/34/cache_stream.py (renamed from synapse/storage/schema/delta/34/cache_stream.py)0
-rw-r--r--synapse/storage/data_stores/main/schema/delta/34/device_inbox.sql (renamed from synapse/storage/schema/delta/34/device_inbox.sql)0
-rw-r--r--synapse/storage/data_stores/main/schema/delta/34/push_display_name_rename.sql (renamed from synapse/storage/schema/delta/34/push_display_name_rename.sql)0
-rw-r--r--synapse/storage/data_stores/main/schema/delta/34/received_txn_purge.py (renamed from synapse/storage/schema/delta/34/received_txn_purge.py)0
-rw-r--r--synapse/storage/data_stores/main/schema/delta/35/contains_url.sql (renamed from synapse/storage/schema/delta/35/contains_url.sql)0
-rw-r--r--synapse/storage/data_stores/main/schema/delta/35/device_outbox.sql (renamed from synapse/storage/schema/delta/35/device_outbox.sql)0
-rw-r--r--synapse/storage/data_stores/main/schema/delta/35/device_stream_id.sql (renamed from synapse/storage/schema/delta/35/device_stream_id.sql)0
-rw-r--r--synapse/storage/data_stores/main/schema/delta/35/event_push_actions_index.sql (renamed from synapse/storage/schema/delta/35/event_push_actions_index.sql)0
-rw-r--r--synapse/storage/data_stores/main/schema/delta/35/public_room_list_change_stream.sql (renamed from synapse/storage/schema/delta/35/public_room_list_change_stream.sql)0
-rw-r--r--synapse/storage/data_stores/main/schema/delta/35/stream_order_to_extrem.sql (renamed from synapse/storage/schema/delta/35/stream_order_to_extrem.sql)0
-rw-r--r--synapse/storage/data_stores/main/schema/delta/36/readd_public_rooms.sql (renamed from synapse/storage/schema/delta/36/readd_public_rooms.sql)0
-rw-r--r--synapse/storage/data_stores/main/schema/delta/37/remove_auth_idx.py (renamed from synapse/storage/schema/delta/37/remove_auth_idx.py)0
-rw-r--r--synapse/storage/data_stores/main/schema/delta/37/user_threepids.sql (renamed from synapse/storage/schema/delta/37/user_threepids.sql)0
-rw-r--r--synapse/storage/data_stores/main/schema/delta/38/postgres_fts_gist.sql (renamed from synapse/storage/schema/delta/38/postgres_fts_gist.sql)0
-rw-r--r--synapse/storage/data_stores/main/schema/delta/39/appservice_room_list.sql (renamed from synapse/storage/schema/delta/39/appservice_room_list.sql)0
-rw-r--r--synapse/storage/data_stores/main/schema/delta/39/device_federation_stream_idx.sql (renamed from synapse/storage/schema/delta/39/device_federation_stream_idx.sql)0
-rw-r--r--synapse/storage/data_stores/main/schema/delta/39/event_push_index.sql (renamed from synapse/storage/schema/delta/39/event_push_index.sql)0
-rw-r--r--synapse/storage/data_stores/main/schema/delta/39/federation_out_position.sql (renamed from synapse/storage/schema/delta/39/federation_out_position.sql)0
-rw-r--r--synapse/storage/data_stores/main/schema/delta/39/membership_profile.sql (renamed from synapse/storage/schema/delta/39/membership_profile.sql)0
-rw-r--r--synapse/storage/data_stores/main/schema/delta/40/current_state_idx.sql (renamed from synapse/storage/schema/delta/40/current_state_idx.sql)0
-rw-r--r--synapse/storage/data_stores/main/schema/delta/40/device_inbox.sql (renamed from synapse/storage/schema/delta/40/device_inbox.sql)0
-rw-r--r--synapse/storage/data_stores/main/schema/delta/40/device_list_streams.sql (renamed from synapse/storage/schema/delta/40/device_list_streams.sql)0
-rw-r--r--synapse/storage/data_stores/main/schema/delta/40/event_push_summary.sql (renamed from synapse/storage/schema/delta/40/event_push_summary.sql)0
-rw-r--r--synapse/storage/data_stores/main/schema/delta/40/pushers.sql (renamed from synapse/storage/schema/delta/40/pushers.sql)0
-rw-r--r--synapse/storage/data_stores/main/schema/delta/41/device_list_stream_idx.sql (renamed from synapse/storage/schema/delta/41/device_list_stream_idx.sql)0
-rw-r--r--synapse/storage/data_stores/main/schema/delta/41/device_outbound_index.sql (renamed from synapse/storage/schema/delta/41/device_outbound_index.sql)0
-rw-r--r--synapse/storage/data_stores/main/schema/delta/41/event_search_event_id_idx.sql (renamed from synapse/storage/schema/delta/41/event_search_event_id_idx.sql)0
-rw-r--r--synapse/storage/data_stores/main/schema/delta/41/ratelimit.sql (renamed from synapse/storage/schema/delta/41/ratelimit.sql)0
-rw-r--r--synapse/storage/data_stores/main/schema/delta/42/current_state_delta.sql (renamed from synapse/storage/schema/delta/42/current_state_delta.sql)0
-rw-r--r--synapse/storage/data_stores/main/schema/delta/42/device_list_last_id.sql (renamed from synapse/storage/schema/delta/42/device_list_last_id.sql)0
-rw-r--r--synapse/storage/data_stores/main/schema/delta/42/event_auth_state_only.sql (renamed from synapse/storage/schema/delta/42/event_auth_state_only.sql)0
-rw-r--r--synapse/storage/data_stores/main/schema/delta/42/user_dir.py (renamed from synapse/storage/schema/delta/42/user_dir.py)0
-rw-r--r--synapse/storage/data_stores/main/schema/delta/43/blocked_rooms.sql (renamed from synapse/storage/schema/delta/43/blocked_rooms.sql)0
-rw-r--r--synapse/storage/data_stores/main/schema/delta/43/quarantine_media.sql (renamed from synapse/storage/schema/delta/43/quarantine_media.sql)0
-rw-r--r--synapse/storage/data_stores/main/schema/delta/43/url_cache.sql (renamed from synapse/storage/schema/delta/43/url_cache.sql)0
-rw-r--r--synapse/storage/data_stores/main/schema/delta/43/user_share.sql (renamed from synapse/storage/schema/delta/43/user_share.sql)0
-rw-r--r--synapse/storage/data_stores/main/schema/delta/44/expire_url_cache.sql (renamed from synapse/storage/schema/delta/44/expire_url_cache.sql)0
-rw-r--r--synapse/storage/data_stores/main/schema/delta/45/group_server.sql (renamed from synapse/storage/schema/delta/45/group_server.sql)0
-rw-r--r--synapse/storage/data_stores/main/schema/delta/45/profile_cache.sql (renamed from synapse/storage/schema/delta/45/profile_cache.sql)0
-rw-r--r--synapse/storage/data_stores/main/schema/delta/46/drop_refresh_tokens.sql (renamed from synapse/storage/schema/delta/46/drop_refresh_tokens.sql)0
-rw-r--r--synapse/storage/data_stores/main/schema/delta/46/drop_unique_deleted_pushers.sql (renamed from synapse/storage/schema/delta/46/drop_unique_deleted_pushers.sql)0
-rw-r--r--synapse/storage/data_stores/main/schema/delta/46/group_server.sql (renamed from synapse/storage/schema/delta/46/group_server.sql)0
-rw-r--r--synapse/storage/data_stores/main/schema/delta/46/local_media_repository_url_idx.sql (renamed from synapse/storage/schema/delta/46/local_media_repository_url_idx.sql)0
-rw-r--r--synapse/storage/data_stores/main/schema/delta/46/user_dir_null_room_ids.sql (renamed from synapse/storage/schema/delta/46/user_dir_null_room_ids.sql)0
-rw-r--r--synapse/storage/data_stores/main/schema/delta/46/user_dir_typos.sql (renamed from synapse/storage/schema/delta/46/user_dir_typos.sql)0
-rw-r--r--synapse/storage/data_stores/main/schema/delta/47/last_access_media.sql (renamed from synapse/storage/schema/delta/47/last_access_media.sql)0
-rw-r--r--synapse/storage/data_stores/main/schema/delta/47/postgres_fts_gin.sql (renamed from synapse/storage/schema/delta/47/postgres_fts_gin.sql)0
-rw-r--r--synapse/storage/data_stores/main/schema/delta/47/push_actions_staging.sql (renamed from synapse/storage/schema/delta/47/push_actions_staging.sql)0
-rw-r--r--synapse/storage/data_stores/main/schema/delta/48/add_user_consent.sql (renamed from synapse/storage/schema/delta/48/add_user_consent.sql)0
-rw-r--r--synapse/storage/data_stores/main/schema/delta/48/add_user_ips_last_seen_index.sql (renamed from synapse/storage/schema/delta/48/add_user_ips_last_seen_index.sql)0
-rw-r--r--synapse/storage/data_stores/main/schema/delta/48/deactivated_users.sql (renamed from synapse/storage/schema/delta/48/deactivated_users.sql)0
-rw-r--r--synapse/storage/data_stores/main/schema/delta/48/group_unique_indexes.py (renamed from synapse/storage/schema/delta/48/group_unique_indexes.py)0
-rw-r--r--synapse/storage/data_stores/main/schema/delta/48/groups_joinable.sql (renamed from synapse/storage/schema/delta/48/groups_joinable.sql)0
-rw-r--r--synapse/storage/data_stores/main/schema/delta/49/add_user_consent_server_notice_sent.sql (renamed from synapse/storage/schema/delta/49/add_user_consent_server_notice_sent.sql)0
-rw-r--r--synapse/storage/data_stores/main/schema/delta/49/add_user_daily_visits.sql (renamed from synapse/storage/schema/delta/49/add_user_daily_visits.sql)0
-rw-r--r--synapse/storage/data_stores/main/schema/delta/49/add_user_ips_last_seen_only_index.sql (renamed from synapse/storage/schema/delta/49/add_user_ips_last_seen_only_index.sql)0
-rw-r--r--synapse/storage/data_stores/main/schema/delta/50/add_creation_ts_users_index.sql (renamed from synapse/storage/schema/delta/50/add_creation_ts_users_index.sql)0
-rw-r--r--synapse/storage/data_stores/main/schema/delta/50/erasure_store.sql (renamed from synapse/storage/schema/delta/50/erasure_store.sql)0
-rw-r--r--synapse/storage/data_stores/main/schema/delta/50/make_event_content_nullable.py (renamed from synapse/storage/schema/delta/50/make_event_content_nullable.py)0
-rw-r--r--synapse/storage/data_stores/main/schema/delta/51/e2e_room_keys.sql (renamed from synapse/storage/schema/delta/51/e2e_room_keys.sql)0
-rw-r--r--synapse/storage/data_stores/main/schema/delta/51/monthly_active_users.sql (renamed from synapse/storage/schema/delta/51/monthly_active_users.sql)0
-rw-r--r--synapse/storage/data_stores/main/schema/delta/52/add_event_to_state_group_index.sql (renamed from synapse/storage/schema/delta/52/add_event_to_state_group_index.sql)0
-rw-r--r--synapse/storage/data_stores/main/schema/delta/52/device_list_streams_unique_idx.sql (renamed from synapse/storage/schema/delta/52/device_list_streams_unique_idx.sql)0
-rw-r--r--synapse/storage/data_stores/main/schema/delta/52/e2e_room_keys.sql (renamed from synapse/storage/schema/delta/52/e2e_room_keys.sql)0
-rw-r--r--synapse/storage/data_stores/main/schema/delta/53/add_user_type_to_users.sql (renamed from synapse/storage/schema/delta/53/add_user_type_to_users.sql)0
-rw-r--r--synapse/storage/data_stores/main/schema/delta/53/drop_sent_transactions.sql (renamed from synapse/storage/schema/delta/53/drop_sent_transactions.sql)0
-rw-r--r--synapse/storage/data_stores/main/schema/delta/53/event_format_version.sql (renamed from synapse/storage/schema/delta/53/event_format_version.sql)0
-rw-r--r--synapse/storage/data_stores/main/schema/delta/53/user_dir_populate.sql (renamed from synapse/storage/schema/delta/53/user_dir_populate.sql)0
-rw-r--r--synapse/storage/data_stores/main/schema/delta/53/user_ips_index.sql (renamed from synapse/storage/schema/delta/53/user_ips_index.sql)0
-rw-r--r--synapse/storage/data_stores/main/schema/delta/53/user_share.sql (renamed from synapse/storage/schema/delta/53/user_share.sql)0
-rw-r--r--synapse/storage/data_stores/main/schema/delta/53/user_threepid_id.sql (renamed from synapse/storage/schema/delta/53/user_threepid_id.sql)0
-rw-r--r--synapse/storage/data_stores/main/schema/delta/53/users_in_public_rooms.sql (renamed from synapse/storage/schema/delta/53/users_in_public_rooms.sql)0
-rw-r--r--synapse/storage/data_stores/main/schema/delta/54/account_validity_with_renewal.sql (renamed from synapse/storage/schema/delta/54/account_validity_with_renewal.sql)0
-rw-r--r--synapse/storage/data_stores/main/schema/delta/54/add_validity_to_server_keys.sql (renamed from synapse/storage/schema/delta/54/add_validity_to_server_keys.sql)0
-rw-r--r--synapse/storage/data_stores/main/schema/delta/54/delete_forward_extremities.sql (renamed from synapse/storage/schema/delta/54/delete_forward_extremities.sql)0
-rw-r--r--synapse/storage/data_stores/main/schema/delta/54/drop_legacy_tables.sql (renamed from synapse/storage/schema/delta/54/drop_legacy_tables.sql)0
-rw-r--r--synapse/storage/data_stores/main/schema/delta/54/drop_presence_list.sql (renamed from synapse/storage/schema/delta/54/drop_presence_list.sql)0
-rw-r--r--synapse/storage/data_stores/main/schema/delta/54/relations.sql (renamed from synapse/storage/schema/delta/54/relations.sql)0
-rw-r--r--synapse/storage/data_stores/main/schema/delta/54/stats.sql (renamed from synapse/storage/schema/delta/54/stats.sql)0
-rw-r--r--synapse/storage/data_stores/main/schema/delta/54/stats2.sql (renamed from synapse/storage/schema/delta/54/stats2.sql)0
-rw-r--r--synapse/storage/data_stores/main/schema/delta/55/access_token_expiry.sql (renamed from synapse/storage/schema/delta/55/access_token_expiry.sql)0
-rw-r--r--synapse/storage/data_stores/main/schema/delta/55/track_threepid_validations.sql (renamed from synapse/storage/schema/delta/55/track_threepid_validations.sql)0
-rw-r--r--synapse/storage/data_stores/main/schema/delta/55/users_alter_deactivated.sql (renamed from synapse/storage/schema/delta/55/users_alter_deactivated.sql)0
-rw-r--r--synapse/storage/data_stores/main/schema/delta/56/add_spans_to_device_lists.sql (renamed from synapse/storage/schema/delta/56/add_spans_to_device_lists.sql)0
-rw-r--r--synapse/storage/data_stores/main/schema/delta/56/current_state_events_membership.sql (renamed from synapse/storage/schema/delta/56/current_state_events_membership.sql)0
-rw-r--r--synapse/storage/data_stores/main/schema/delta/56/current_state_events_membership_mk2.sql (renamed from synapse/storage/schema/delta/56/current_state_events_membership_mk2.sql)0
-rw-r--r--synapse/storage/data_stores/main/schema/delta/56/delete_keys_from_deleted_backups.sql25
-rw-r--r--synapse/storage/data_stores/main/schema/delta/56/destinations_failure_ts.sql (renamed from synapse/storage/schema/delta/56/destinations_failure_ts.sql)0
-rw-r--r--synapse/storage/data_stores/main/schema/delta/56/destinations_retry_interval_type.sql.postgres18
-rw-r--r--synapse/storage/data_stores/main/schema/delta/56/device_stream_id_insert.sql20
-rw-r--r--synapse/storage/data_stores/main/schema/delta/56/devices_last_seen.sql24
-rw-r--r--synapse/storage/data_stores/main/schema/delta/56/drop_unused_event_tables.sql (renamed from synapse/storage/schema/delta/56/drop_unused_event_tables.sql)0
-rw-r--r--synapse/storage/data_stores/main/schema/delta/56/event_expiry.sql21
-rw-r--r--synapse/storage/data_stores/main/schema/delta/56/event_labels.sql30
-rw-r--r--synapse/storage/data_stores/main/schema/delta/56/event_labels_background_update.sql17
-rw-r--r--synapse/storage/data_stores/main/schema/delta/56/fix_room_keys_index.sql (renamed from synapse/storage/schema/delta/56/fix_room_keys_index.sql)0
-rw-r--r--synapse/storage/data_stores/main/schema/delta/56/hidden_devices.sql18
-rw-r--r--synapse/storage/data_stores/main/schema/delta/56/hidden_devices_fix.sql.sqlite42
-rw-r--r--synapse/storage/data_stores/main/schema/delta/56/nuke_empty_communities_from_db.sql29
-rw-r--r--synapse/storage/data_stores/main/schema/delta/56/public_room_list_idx.sql16
-rw-r--r--synapse/storage/data_stores/main/schema/delta/56/redaction_censor.sql (renamed from synapse/storage/schema/delta/56/redaction_censor.sql)1
-rw-r--r--synapse/storage/data_stores/main/schema/delta/56/redaction_censor2.sql22
-rw-r--r--synapse/storage/data_stores/main/schema/delta/56/redaction_censor3_fix_update.sql.postgres25
-rw-r--r--synapse/storage/data_stores/main/schema/delta/56/redaction_censor4.sql16
-rw-r--r--synapse/storage/data_stores/main/schema/delta/56/remove_tombstoned_rooms_from_directory.sql18
-rw-r--r--synapse/storage/data_stores/main/schema/delta/56/room_key_etag.sql17
-rw-r--r--synapse/storage/data_stores/main/schema/delta/56/room_membership_idx.sql (renamed from synapse/storage/schema/delta/56/room_membership_idx.sql)0
-rw-r--r--synapse/storage/data_stores/main/schema/delta/56/room_retention.sql33
-rw-r--r--synapse/storage/data_stores/main/schema/delta/56/signing_keys.sql56
-rw-r--r--synapse/storage/data_stores/main/schema/delta/56/signing_keys_nonunique_signatures.sql22
-rw-r--r--synapse/storage/data_stores/main/schema/delta/56/stats_separated.sql (renamed from synapse/storage/schema/delta/56/stats_separated.sql)6
-rw-r--r--synapse/storage/data_stores/main/schema/delta/56/unique_user_filter_index.py52
-rw-r--r--synapse/storage/data_stores/main/schema/delta/56/user_external_ids.sql24
-rw-r--r--synapse/storage/data_stores/main/schema/delta/56/users_in_public_rooms_idx.sql (renamed from synapse/storage/schema/delta/56/users_in_public_rooms_idx.sql)0
-rw-r--r--synapse/storage/data_stores/main/schema/delta/57/delete_old_current_state_events.sql22
-rw-r--r--synapse/storage/data_stores/main/schema/delta/57/device_list_remote_cache_stale.sql25
-rw-r--r--synapse/storage/data_stores/main/schema/delta/57/local_current_membership.py98
-rw-r--r--synapse/storage/data_stores/main/schema/delta/57/remove_sent_outbound_pokes.sql21
-rw-r--r--synapse/storage/data_stores/main/schema/delta/57/rooms_version_column.sql24
-rw-r--r--synapse/storage/data_stores/main/schema/delta/57/rooms_version_column_2.sql.postgres35
-rw-r--r--synapse/storage/data_stores/main/schema/delta/57/rooms_version_column_2.sql.sqlite22
-rw-r--r--synapse/storage/data_stores/main/schema/delta/57/rooms_version_column_3.sql.postgres39
-rw-r--r--synapse/storage/data_stores/main/schema/delta/57/rooms_version_column_3.sql.sqlite23
-rw-r--r--synapse/storage/data_stores/main/schema/delta/58/02remove_dup_outbound_pokes.sql22
-rw-r--r--synapse/storage/data_stores/main/schema/delta/58/03persist_ui_auth.sql36
-rw-r--r--synapse/storage/data_stores/main/schema/delta/58/05cache_instance.sql.postgres30
-rw-r--r--synapse/storage/data_stores/main/schema/delta/58/06dlols_unique_idx.py80
-rw-r--r--synapse/storage/data_stores/main/schema/full_schemas/16/application_services.sql (renamed from synapse/storage/schema/full_schemas/16/application_services.sql)0
-rw-r--r--synapse/storage/data_stores/main/schema/full_schemas/16/event_edges.sql (renamed from synapse/storage/schema/full_schemas/16/event_edges.sql)0
-rw-r--r--synapse/storage/data_stores/main/schema/full_schemas/16/event_signatures.sql (renamed from synapse/storage/schema/full_schemas/16/event_signatures.sql)0
-rw-r--r--synapse/storage/data_stores/main/schema/full_schemas/16/im.sql (renamed from synapse/storage/schema/full_schemas/16/im.sql)0
-rw-r--r--synapse/storage/data_stores/main/schema/full_schemas/16/keys.sql (renamed from synapse/storage/schema/full_schemas/16/keys.sql)0
-rw-r--r--synapse/storage/data_stores/main/schema/full_schemas/16/media_repository.sql (renamed from synapse/storage/schema/full_schemas/16/media_repository.sql)0
-rw-r--r--synapse/storage/data_stores/main/schema/full_schemas/16/presence.sql (renamed from synapse/storage/schema/full_schemas/16/presence.sql)0
-rw-r--r--synapse/storage/data_stores/main/schema/full_schemas/16/profiles.sql (renamed from synapse/storage/schema/full_schemas/16/profiles.sql)0
-rw-r--r--synapse/storage/data_stores/main/schema/full_schemas/16/push.sql (renamed from synapse/storage/schema/full_schemas/16/push.sql)0
-rw-r--r--synapse/storage/data_stores/main/schema/full_schemas/16/redactions.sql (renamed from synapse/storage/schema/full_schemas/16/redactions.sql)0
-rw-r--r--synapse/storage/data_stores/main/schema/full_schemas/16/room_aliases.sql (renamed from synapse/storage/schema/full_schemas/16/room_aliases.sql)0
-rw-r--r--synapse/storage/data_stores/main/schema/full_schemas/16/state.sql (renamed from synapse/storage/schema/full_schemas/16/state.sql)0
-rw-r--r--synapse/storage/data_stores/main/schema/full_schemas/16/transactions.sql (renamed from synapse/storage/schema/full_schemas/16/transactions.sql)0
-rw-r--r--synapse/storage/data_stores/main/schema/full_schemas/16/users.sql (renamed from synapse/storage/schema/full_schemas/16/users.sql)0
-rw-r--r--synapse/storage/data_stores/main/schema/full_schemas/54/full.sql.postgres (renamed from synapse/storage/schema/full_schemas/54/full.sql.postgres)69
-rw-r--r--synapse/storage/data_stores/main/schema/full_schemas/54/full.sql.sqlite (renamed from synapse/storage/schema/full_schemas/54/full.sql.sqlite)7
-rw-r--r--synapse/storage/data_stores/main/schema/full_schemas/54/stream_positions.sql (renamed from synapse/storage/schema/full_schemas/54/stream_positions.sql)1
-rw-r--r--synapse/storage/data_stores/main/schema/full_schemas/README.md21
-rw-r--r--synapse/storage/data_stores/main/search.py (renamed from synapse/storage/search.py)219
-rw-r--r--synapse/storage/data_stores/main/signatures.py (renamed from synapse/storage/signatures.py)35
-rw-r--r--synapse/storage/data_stores/main/state.py467
-rw-r--r--synapse/storage/data_stores/main/state_deltas.py (renamed from synapse/storage/state_deltas.py)46
-rw-r--r--synapse/storage/data_stores/main/stats.py (renamed from synapse/storage/stats.py)113
-rw-r--r--synapse/storage/data_stores/main/stream.py (renamed from synapse/storage/stream.py)107
-rw-r--r--synapse/storage/data_stores/main/tags.py (renamed from synapse/storage/tags.py)27
-rw-r--r--synapse/storage/data_stores/main/transactions.py (renamed from synapse/storage/transactions.py)39
-rw-r--r--synapse/storage/data_stores/main/ui_auth.py300
-rw-r--r--synapse/storage/data_stores/main/user_directory.py (renamed from synapse/storage/user_directory.py)278
-rw-r--r--synapse/storage/data_stores/main/user_erasure_store.py (renamed from synapse/storage/user_erasure_store.py)24
-rw-r--r--synapse/storage/data_stores/state/__init__.py16
-rw-r--r--synapse/storage/data_stores/state/bg_updates.py374
-rw-r--r--synapse/storage/data_stores/state/schema/delta/23/drop_state_index.sql (renamed from synapse/storage/schema/delta/23/drop_state_index.sql)0
-rw-r--r--synapse/storage/data_stores/state/schema/delta/30/state_stream.sql (renamed from synapse/storage/schema/delta/30/state_stream.sql)0
-rw-r--r--synapse/storage/data_stores/state/schema/delta/32/remove_state_indices.sql19
-rw-r--r--synapse/storage/data_stores/state/schema/delta/35/add_state_index.sql (renamed from synapse/storage/schema/delta/35/add_state_index.sql)3
-rw-r--r--synapse/storage/data_stores/state/schema/delta/35/state.sql (renamed from synapse/storage/schema/delta/35/state.sql)0
-rw-r--r--synapse/storage/data_stores/state/schema/delta/35/state_dedupe.sql (renamed from synapse/storage/schema/delta/35/state_dedupe.sql)0
-rw-r--r--synapse/storage/data_stores/state/schema/delta/47/state_group_seq.py (renamed from synapse/storage/schema/delta/47/state_group_seq.py)0
-rw-r--r--synapse/storage/data_stores/state/schema/delta/56/state_group_room_idx.sql17
-rw-r--r--synapse/storage/data_stores/state/schema/full_schemas/54/full.sql37
-rw-r--r--synapse/storage/data_stores/state/schema/full_schemas/54/sequence.sql.postgres21
-rw-r--r--synapse/storage/data_stores/state/store.py642
-rw-r--r--synapse/storage/database.py1644
-rw-r--r--synapse/storage/end_to_end_keys.py313
-rw-r--r--synapse/storage/engines/__init__.py28
-rw-r--r--synapse/storage/engines/_base.py87
-rw-r--r--synapse/storage/engines/postgres.py108
-rw-r--r--synapse/storage/engines/sqlite.py55
-rw-r--r--synapse/storage/events.py2470
-rw-r--r--synapse/storage/keys.py194
-rw-r--r--synapse/storage/persist_events.py794
-rw-r--r--synapse/storage/prepare_database.py235
-rw-r--r--synapse/storage/presence.py136
-rw-r--r--synapse/storage/purge_events.py117
-rw-r--r--synapse/storage/push_rule.py702
-rw-r--r--synapse/storage/relations.py359
-rw-r--r--synapse/storage/room.py585
-rw-r--r--synapse/storage/roommember.py1072
-rw-r--r--synapse/storage/schema/delta/35/00background_updates_add_col.sql17
-rw-r--r--synapse/storage/schema/delta/58/00background_update_ordering.sql19
-rw-r--r--synapse/storage/schema/full_schemas/54/full.sql8
-rw-r--r--synapse/storage/schema/full_schemas/README.txt19
-rw-r--r--synapse/storage/state.py1049
-rw-r--r--synapse/storage/types.py65
-rw-r--r--synapse/storage/util/id_generators.py182
294 files changed, 18475 insertions, 11621 deletions
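
Most of the diffstat above is a directory reorganisation rather than new code: the per-table store modules and their schema deltas move out of synapse/storage/ into synapse/storage/data_stores/main/, with the state-group tables split into synapse/storage/data_stores/state/. Assuming the store class names are unchanged by the move (which the rename entries suggest), the import paths shift roughly as in this illustrative sketch, using RegistrationStore as an example:

    # Old layout: each store lived directly under synapse/storage/
    from synapse.storage.registration import RegistrationStore

    # New layout: per-database stores live under data_stores/, here the "main" database
    from synapse.storage.data_stores.main.registration import RegistrationStore
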
diff --git a/synapse/storage/__init__.py b/synapse/storage/__init__.py
index e7f6ea7286..ec89f645d4 100644
--- a/synapse/storage/__init__.py
+++ b/synapse/storage/__init__.py
@@ -1,6 +1,6 @@
 # -*- coding: utf-8 -*-
 # Copyright 2014-2016 OpenMarket Ltd
-# Copyright 2018 New Vector Ltd
+# Copyright 2018,2019 New Vector Ltd
 #
 # Licensed under the Apache License, Version 2.0 (the "License");
 # you may not use this file except in compliance with the License.
@@ -14,518 +14,38 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
-import calendar
-import logging
-import time
-
-from twisted.internet import defer
-
-from synapse.api.constants import PresenceState
-from synapse.storage.devices import DeviceStore
-from synapse.storage.user_erasure_store import UserErasureStore
-from synapse.util.caches.stream_change_cache import StreamChangeCache
-
-from .account_data import AccountDataStore
-from .appservice import ApplicationServiceStore, ApplicationServiceTransactionStore
-from .client_ips import ClientIpStore
-from .deviceinbox import DeviceInboxStore
-from .directory import DirectoryStore
-from .e2e_room_keys import EndToEndRoomKeyStore
-from .end_to_end_keys import EndToEndKeyStore
-from .engines import PostgresEngine
-from .event_federation import EventFederationStore
-from .event_push_actions import EventPushActionsStore
-from .events import EventsStore
-from .events_bg_updates import EventsBackgroundUpdatesStore
-from .filtering import FilteringStore
-from .group_server import GroupServerStore
-from .keys import KeyStore
-from .media_repository import MediaRepositoryStore
-from .monthly_active_users import MonthlyActiveUsersStore
-from .openid import OpenIdStore
-from .presence import PresenceStore, UserPresenceState
-from .profile import ProfileStore
-from .push_rule import PushRuleStore
-from .pusher import PusherStore
-from .receipts import ReceiptsStore
-from .registration import RegistrationStore
-from .rejections import RejectionsStore
-from .relations import RelationsStore
-from .room import RoomStore
-from .roommember import RoomMemberStore
-from .search import SearchStore
-from .signatures import SignatureStore
-from .state import StateStore
-from .stats import StatsStore
-from .stream import StreamStore
-from .tags import TagsStore
-from .transactions import TransactionStore
-from .user_directory import UserDirectoryStore
-from .util.id_generators import ChainedIdGenerator, IdGenerator, StreamIdGenerator
-
-logger = logging.getLogger(__name__)
-
-
-class DataStore(
-    EventsBackgroundUpdatesStore,
-    RoomMemberStore,
-    RoomStore,
-    RegistrationStore,
-    StreamStore,
-    ProfileStore,
-    PresenceStore,
-    TransactionStore,
-    DirectoryStore,
-    KeyStore,
-    StateStore,
-    SignatureStore,
-    ApplicationServiceStore,
-    EventsStore,
-    EventFederationStore,
-    MediaRepositoryStore,
-    RejectionsStore,
-    FilteringStore,
-    PusherStore,
-    PushRuleStore,
-    ApplicationServiceTransactionStore,
-    ReceiptsStore,
-    EndToEndKeyStore,
-    EndToEndRoomKeyStore,
-    SearchStore,
-    TagsStore,
-    AccountDataStore,
-    EventPushActionsStore,
-    OpenIdStore,
-    ClientIpStore,
-    DeviceStore,
-    DeviceInboxStore,
-    UserDirectoryStore,
-    GroupServerStore,
-    UserErasureStore,
-    MonthlyActiveUsersStore,
-    StatsStore,
-    RelationsStore,
-):
-    def __init__(self, db_conn, hs):
-        self.hs = hs
-        self._clock = hs.get_clock()
-        self.database_engine = hs.database_engine
-
-        self._stream_id_gen = StreamIdGenerator(
-            db_conn,
-            "events",
-            "stream_ordering",
-            extra_tables=[("local_invites", "stream_id")],
-        )
-        self._backfill_id_gen = StreamIdGenerator(
-            db_conn,
-            "events",
-            "stream_ordering",
-            step=-1,
-            extra_tables=[("ex_outlier_stream", "event_stream_ordering")],
-        )
-        self._presence_id_gen = StreamIdGenerator(
-            db_conn, "presence_stream", "stream_id"
-        )
-        self._device_inbox_id_gen = StreamIdGenerator(
-            db_conn, "device_max_stream_id", "stream_id"
-        )
-        self._public_room_id_gen = StreamIdGenerator(
-            db_conn, "public_room_list_stream", "stream_id"
-        )
-        self._device_list_id_gen = StreamIdGenerator(
-            db_conn, "device_lists_stream", "stream_id"
-        )
-
-        self._access_tokens_id_gen = IdGenerator(db_conn, "access_tokens", "id")
-        self._event_reports_id_gen = IdGenerator(db_conn, "event_reports", "id")
-        self._push_rule_id_gen = IdGenerator(db_conn, "push_rules", "id")
-        self._push_rules_enable_id_gen = IdGenerator(db_conn, "push_rules_enable", "id")
-        self._push_rules_stream_id_gen = ChainedIdGenerator(
-            self._stream_id_gen, db_conn, "push_rules_stream", "stream_id"
-        )
-        self._pushers_id_gen = StreamIdGenerator(
-            db_conn, "pushers", "id", extra_tables=[("deleted_pushers", "stream_id")]
-        )
-        self._group_updates_id_gen = StreamIdGenerator(
-            db_conn, "local_group_updates", "stream_id"
-        )
-
-        if isinstance(self.database_engine, PostgresEngine):
-            self._cache_id_gen = StreamIdGenerator(
-                db_conn, "cache_invalidation_stream", "stream_id"
-            )
-        else:
-            self._cache_id_gen = None
-
-        self._presence_on_startup = self._get_active_presence(db_conn)
-
-        presence_cache_prefill, min_presence_val = self._get_cache_dict(
-            db_conn,
-            "presence_stream",
-            entity_column="user_id",
-            stream_column="stream_id",
-            max_value=self._presence_id_gen.get_current_token(),
-        )
-        self.presence_stream_cache = StreamChangeCache(
-            "PresenceStreamChangeCache",
-            min_presence_val,
-            prefilled_cache=presence_cache_prefill,
-        )
-
-        max_device_inbox_id = self._device_inbox_id_gen.get_current_token()
-        device_inbox_prefill, min_device_inbox_id = self._get_cache_dict(
-            db_conn,
-            "device_inbox",
-            entity_column="user_id",
-            stream_column="stream_id",
-            max_value=max_device_inbox_id,
-            limit=1000,
-        )
-        self._device_inbox_stream_cache = StreamChangeCache(
-            "DeviceInboxStreamChangeCache",
-            min_device_inbox_id,
-            prefilled_cache=device_inbox_prefill,
-        )
-        # The federation outbox and the local device inbox uses the same
-        # stream_id generator.
-        device_outbox_prefill, min_device_outbox_id = self._get_cache_dict(
-            db_conn,
-            "device_federation_outbox",
-            entity_column="destination",
-            stream_column="stream_id",
-            max_value=max_device_inbox_id,
-            limit=1000,
-        )
-        self._device_federation_outbox_stream_cache = StreamChangeCache(
-            "DeviceFederationOutboxStreamChangeCache",
-            min_device_outbox_id,
-            prefilled_cache=device_outbox_prefill,
-        )
-
-        device_list_max = self._device_list_id_gen.get_current_token()
-        self._device_list_stream_cache = StreamChangeCache(
-            "DeviceListStreamChangeCache", device_list_max
-        )
-        self._device_list_federation_stream_cache = StreamChangeCache(
-            "DeviceListFederationStreamChangeCache", device_list_max
-        )
-
-        events_max = self._stream_id_gen.get_current_token()
-        curr_state_delta_prefill, min_curr_state_delta_id = self._get_cache_dict(
-            db_conn,
-            "current_state_delta_stream",
-            entity_column="room_id",
-            stream_column="stream_id",
-            max_value=events_max,  # As we share the stream id with events token
-            limit=1000,
-        )
-        self._curr_state_delta_stream_cache = StreamChangeCache(
-            "_curr_state_delta_stream_cache",
-            min_curr_state_delta_id,
-            prefilled_cache=curr_state_delta_prefill,
-        )
-
-        _group_updates_prefill, min_group_updates_id = self._get_cache_dict(
-            db_conn,
-            "local_group_updates",
-            entity_column="user_id",
-            stream_column="stream_id",
-            max_value=self._group_updates_id_gen.get_current_token(),
-            limit=1000,
-        )
-        self._group_updates_stream_cache = StreamChangeCache(
-            "_group_updates_stream_cache",
-            min_group_updates_id,
-            prefilled_cache=_group_updates_prefill,
-        )
-
-        self._stream_order_on_start = self.get_room_max_stream_ordering()
-        self._min_stream_order_on_start = self.get_room_min_stream_ordering()
-
-        # Used in _generate_user_daily_visits to keep track of progress
-        self._last_user_visit_update = self._get_start_of_day()
-
-        super(DataStore, self).__init__(db_conn, hs)
-
-    def take_presence_startup_info(self):
-        active_on_startup = self._presence_on_startup
-        self._presence_on_startup = None
-        return active_on_startup
-
-    def _get_active_presence(self, db_conn):
-        """Fetch non-offline presence from the database so that we can register
-        the appropriate time outs.
-        """
-
-        sql = (
-            "SELECT user_id, state, last_active_ts, last_federation_update_ts,"
-            " last_user_sync_ts, status_msg, currently_active FROM presence_stream"
-            " WHERE state != ?"
-        )
-        sql = self.database_engine.convert_param_style(sql)
-
-        txn = db_conn.cursor()
-        txn.execute(sql, (PresenceState.OFFLINE,))
-        rows = self.cursor_to_dict(txn)
-        txn.close()
-
-        for row in rows:
-            row["currently_active"] = bool(row["currently_active"])
-
-        return [UserPresenceState(**row) for row in rows]
-
-    def count_daily_users(self):
-        """
-        Counts the number of users who used this homeserver in the last 24 hours.
-        """
-        yesterday = int(self._clock.time_msec()) - (1000 * 60 * 60 * 24)
-        return self.runInteraction("count_daily_users", self._count_users, yesterday)
-
-    def count_monthly_users(self):
-        """
-        Counts the number of users who used this homeserver in the last 30 days.
-        Note this method is intended for phonehome metrics only and is different
-        from the mau figure in synapse.storage.monthly_active_users which,
-        amongst other things, includes a 3 day grace period before a user counts.
-        """
-        thirty_days_ago = int(self._clock.time_msec()) - (1000 * 60 * 60 * 24 * 30)
-        return self.runInteraction(
-            "count_monthly_users", self._count_users, thirty_days_ago
-        )
-
-    def _count_users(self, txn, time_from):
-        """
-        Returns number of users seen in the past time_from period
-        """
-        sql = """
-            SELECT COALESCE(count(*), 0) FROM (
-                SELECT user_id FROM user_ips
-                WHERE last_seen > ?
-                GROUP BY user_id
-            ) u
-        """
-        txn.execute(sql, (time_from,))
-        count, = txn.fetchone()
-        return count
-
-    def count_r30_users(self):
-        """
-        Counts the number of 30 day retained users, defined as:-
-         * Users who have created their accounts more than 30 days ago
-         * Where last seen at most 30 days ago
-         * Where account creation and last_seen are > 30 days apart
-
-         Returns counts globaly for a given user as well as breaking
-         by platform
-        """
-
-        def _count_r30_users(txn):
-            thirty_days_in_secs = 86400 * 30
-            now = int(self._clock.time())
-            thirty_days_ago_in_secs = now - thirty_days_in_secs
-
-            sql = """
-                SELECT platform, COALESCE(count(*), 0) FROM (
-                     SELECT
-                        users.name, platform, users.creation_ts * 1000,
-                        MAX(uip.last_seen)
-                     FROM users
-                     INNER JOIN (
-                         SELECT
-                         user_id,
-                         last_seen,
-                         CASE
-                             WHEN user_agent LIKE '%%Android%%' THEN 'android'
-                             WHEN user_agent LIKE '%%iOS%%' THEN 'ios'
-                             WHEN user_agent LIKE '%%Electron%%' THEN 'electron'
-                             WHEN user_agent LIKE '%%Mozilla%%' THEN 'web'
-                             WHEN user_agent LIKE '%%Gecko%%' THEN 'web'
-                             ELSE 'unknown'
-                         END
-                         AS platform
-                         FROM user_ips
-                     ) uip
-                     ON users.name = uip.user_id
-                     AND users.appservice_id is NULL
-                     AND users.creation_ts < ?
-                     AND uip.last_seen/1000 > ?
-                     AND (uip.last_seen/1000) - users.creation_ts > 86400 * 30
-                     GROUP BY users.name, platform, users.creation_ts
-                ) u GROUP BY platform
-            """
-
-            results = {}
-            txn.execute(sql, (thirty_days_ago_in_secs, thirty_days_ago_in_secs))
-
-            for row in txn:
-                if row[0] == "unknown":
-                    pass
-                results[row[0]] = row[1]
-
-            sql = """
-                SELECT COALESCE(count(*), 0) FROM (
-                    SELECT users.name, users.creation_ts * 1000,
-                                                        MAX(uip.last_seen)
-                    FROM users
-                    INNER JOIN (
-                        SELECT
-                        user_id,
-                        last_seen
-                        FROM user_ips
-                    ) uip
-                    ON users.name = uip.user_id
-                    AND appservice_id is NULL
-                    AND users.creation_ts < ?
-                    AND uip.last_seen/1000 > ?
-                    AND (uip.last_seen/1000) - users.creation_ts > 86400 * 30
-                    GROUP BY users.name, users.creation_ts
-                ) u
-            """
-
-            txn.execute(sql, (thirty_days_ago_in_secs, thirty_days_ago_in_secs))
-
-            count, = txn.fetchone()
-            results["all"] = count
-
-            return results
-
-        return self.runInteraction("count_r30_users", _count_r30_users)
-
-    def _get_start_of_day(self):
-        """
-        Returns millisecond unixtime for start of UTC day.
-        """
-        now = time.gmtime()
-        today_start = calendar.timegm((now.tm_year, now.tm_mon, now.tm_mday, 0, 0, 0))
-        return today_start * 1000
-
-    def generate_user_daily_visits(self):
-        """
-        Generates daily visit data for use in cohort/ retention analysis
-        """
-
-        def _generate_user_daily_visits(txn):
-            logger.info("Calling _generate_user_daily_visits")
-            today_start = self._get_start_of_day()
-            a_day_in_milliseconds = 24 * 60 * 60 * 1000
-            now = self.clock.time_msec()
-
-            sql = """
-                INSERT INTO user_daily_visits (user_id, device_id, timestamp)
-                    SELECT u.user_id, u.device_id, ?
-                    FROM user_ips AS u
-                    LEFT JOIN (
-                      SELECT user_id, device_id, timestamp FROM user_daily_visits
-                      WHERE timestamp = ?
-                    ) udv
-                    ON u.user_id = udv.user_id AND u.device_id=udv.device_id
-                    INNER JOIN users ON users.name=u.user_id
-                    WHERE last_seen > ? AND last_seen <= ?
-                    AND udv.timestamp IS NULL AND users.is_guest=0
-                    AND users.appservice_id IS NULL
-                    GROUP BY u.user_id, u.device_id
-            """
-
-            # This means that the day has rolled over but there could still
-            # be entries from the previous day. There is an edge case
-            # where if the user logs in at 23:59 and overwrites their
-            # last_seen at 00:01 then they will not be counted in the
-            # previous day's stats - it is important that the query is run
-            # often to minimise this case.
-            if today_start > self._last_user_visit_update:
-                yesterday_start = today_start - a_day_in_milliseconds
-                txn.execute(
-                    sql,
-                    (
-                        yesterday_start,
-                        yesterday_start,
-                        self._last_user_visit_update,
-                        today_start,
-                    ),
-                )
-                self._last_user_visit_update = today_start
-
-            txn.execute(
-                sql, (today_start, today_start, self._last_user_visit_update, now)
-            )
-            # Update _last_user_visit_update to now. The reason to do this
-            # rather just clamping to the beginning of the day is to limit
-            # the size of the join - meaning that the query can be run more
-            # frequently
-            self._last_user_visit_update = now
-
-        return self.runInteraction(
-            "generate_user_daily_visits", _generate_user_daily_visits
-        )
-
-    def get_users(self):
-        """Function to reterive a list of users in users table.
-
-        Args:
-        Returns:
-            defer.Deferred: resolves to list[dict[str, Any]]
-        """
-        return self._simple_select_list(
-            table="users",
-            keyvalues={},
-            retcols=["name", "password_hash", "is_guest", "admin", "user_type"],
-            desc="get_users",
-        )
-
-    @defer.inlineCallbacks
-    def get_users_paginate(self, order, start, limit):
-        """Function to reterive a paginated list of users from
-        users list. This will return a json object, which contains
-        list of users and the total number of users in users table.
-
-        Args:
-            order (str): column name to order the select by this column
-            start (int): start number to begin the query from
-            limit (int): number of rows to reterive
-        Returns:
-            defer.Deferred: resolves to json object {list[dict[str, Any]], count}
-        """
-        users = yield self.runInteraction(
-            "get_users_paginate",
-            self._simple_select_list_paginate_txn,
-            table="users",
-            keyvalues={"is_guest": False},
-            orderby=order,
-            start=start,
-            limit=limit,
-            retcols=["name", "password_hash", "is_guest", "admin", "user_type"],
-        )
-        count = yield self.runInteraction("get_users_paginate", self.get_user_count_txn)
-        retval = {"users": users, "total": count}
-        return retval
-
-    def search_users(self, term):
-        """Function to search users list for one or more users with
-        the matched term.
-
-        Args:
-            term (str): search term
-            col (str): column to query term should be matched to
-        Returns:
-            defer.Deferred: resolves to list[dict[str, Any]]
-        """
-        return self._simple_search_list(
-            table="users",
-            term=term,
-            col="name",
-            retcols=["name", "password_hash", "is_guest", "admin", "user_type"],
-            desc="search_users",
-        )
-
-
-def are_all_users_on_domain(txn, database_engine, domain):
-    sql = database_engine.convert_param_style(
-        "SELECT COUNT(*) FROM users WHERE name NOT LIKE ?"
-    )
-    pat = "%:" + domain
-    txn.execute(sql, (pat,))
-    num_not_matching = txn.fetchall()[0][0]
-    if num_not_matching == 0:
-        return True
-    return False
+"""
+The storage layer is split up into multiple parts to allow Synapse to run
+against different configurations of databases (e.g. single or multiple
+databases). The `Database` class represents a single physical database. The
+`data_stores` are classes that talk directly to a `Database` instance and have
+associated schemas, background updates, etc. On top of those there are classes
+that provide high level interfaces that combine calls to multiple `data_stores`.
+
+There are also schemas that get applied to every database, regardless of the
+data stores associated with them (e.g. the schema version tables), which are
+stored in `synapse.storage.schema`.
+"""
+
+from synapse.storage.data_stores import DataStores
+from synapse.storage.data_stores.main import DataStore
+from synapse.storage.persist_events import EventsPersistenceStorage
+from synapse.storage.purge_events import PurgeEventsStorage
+from synapse.storage.state import StateGroupStorage
+
+__all__ = ["DataStores", "DataStore"]
+
+
+class Storage(object):
+    """The high level interfaces for talking to various storage layers.
+    """
+
+    def __init__(self, hs, stores: DataStores):
+        # We include the main data store here mainly so that we don't have to
+        # rewrite all the existing code to split it into high vs low level
+        # interfaces.
+        self.main = stores.main
+
+        self.persistence = EventsPersistenceStorage(hs, stores)
+        self.purge_events = PurgeEventsStorage(hs, stores)
+        self.state = StateGroupStorage(hs, stores)
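For orientation, here is a minimal usage sketch of the new high-level wrapper. The `hs` (homeserver) and `stores` objects are assumed to be created by Synapse's startup code elsewhere; the snippet is illustrative only.

    # Hypothetical setup: `hs` and `stores` come from homeserver startup.
    storage = Storage(hs, stores)

    storage.main          # the existing (unsplit) main data store
    storage.persistence   # high level event persistence interface
    storage.purge_events  # high level purge interface
    storage.state         # high level state-group interface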
diff --git a/synapse/storage/_base.py b/synapse/storage/_base.py
index abe16334ec..bfce541ca7 100644
--- a/synapse/storage/_base.py
+++ b/synapse/storage/_base.py
@@ -14,1403 +14,38 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
-import itertools
 import logging
 import random
-import sys
-import threading
-import time
-
-from six import PY2, iteritems, iterkeys, itervalues
-from six.moves import builtins, intern, range
+from abc import ABCMeta
+from typing import Any, Optional
 
 from canonicaljson import json
-from prometheus_client import Histogram
-
-from twisted.internet import defer
-
-from synapse.api.errors import StoreError
-from synapse.logging.context import LoggingContext, PreserveLoggingContext
-from synapse.metrics.background_process_metrics import run_as_background_process
-from synapse.storage.engines import PostgresEngine, Sqlite3Engine
-from synapse.types import get_domain_from_id
-from synapse.util import batch_iter
-from synapse.util.caches.descriptors import Cache
-from synapse.util.stringutils import exception_to_unicode
 
-# import a function which will return a monotonic time, in seconds
-try:
-    # on python 3, use time.monotonic, since time.clock can go backwards
-    from time import monotonic as monotonic_time
-except ImportError:
-    # ... but python 2 doesn't have it
-    from time import clock as monotonic_time
+from synapse.storage.database import LoggingTransaction  # noqa: F401
+from synapse.storage.database import make_in_list_sql_clause  # noqa: F401
+from synapse.storage.database import Database
+from synapse.types import Collection, get_domain_from_id
 
 logger = logging.getLogger(__name__)
 
-try:
-    MAX_TXN_ID = sys.maxint - 1
-except AttributeError:
-    # python 3 does not have a maximum int value
-    MAX_TXN_ID = 2 ** 63 - 1
-
-sql_logger = logging.getLogger("synapse.storage.SQL")
-transaction_logger = logging.getLogger("synapse.storage.txn")
-perf_logger = logging.getLogger("synapse.storage.TIME")
-
-sql_scheduling_timer = Histogram("synapse_storage_schedule_time", "sec")
-
-sql_query_timer = Histogram("synapse_storage_query_time", "sec", ["verb"])
-sql_txn_timer = Histogram("synapse_storage_transaction_time", "sec", ["desc"])
-
-
-# Unique indexes which have been added in background updates. Maps from table name
-# to the name of the background update which added the unique index to that table.
-#
-# This is used by the upsert logic to figure out which tables are safe to do a proper
-# UPSERT on: until the relevant background update has completed, we
-# have to emulate an upsert by locking the table.
-#
-UNIQUE_INDEX_BACKGROUND_UPDATES = {
-    "user_ips": "user_ips_device_unique_index",
-    "device_lists_remote_extremeties": "device_lists_remote_extremeties_unique_idx",
-    "device_lists_remote_cache": "device_lists_remote_cache_unique_idx",
-    "event_search": "event_search_event_id_idx",
-}
-
-# This is a special cache name we use to batch multiple invalidations of caches
-# based on the current state when notifying workers over replication.
-_CURRENT_STATE_CACHE_NAME = "cs_cache_fake"
-
 
-class LoggingTransaction(object):
-    """An object that almost-transparently proxies for the 'txn' object
-    passed to the constructor. Adds logging and metrics to the .execute()
-    method.
+# some of our subclasses have abstract methods, so we use the ABCMeta metaclass.
+class SQLBaseStore(metaclass=ABCMeta):
+    """Base class for data stores that holds helper functions.
 
-    Args:
-        txn: The database transaction object to wrap.
-        name (str): The name of this transaction, for logging.
-        database_engine (Sqlite3Engine|PostgresEngine)
-        after_callbacks(list|None): A list to which callbacks added via
-            `call_after` will be appended; these should be run on successful
-            completion of the transaction. None indicates that no callbacks
-            should be allowed to be scheduled to run.
-        exception_callbacks(list|None): A list to which callbacks added via
-            `call_on_exception` will be appended; these should be run if the
-            transaction ends with an error. None indicates that no callbacks
-            should be allowed to be scheduled to run.
+    Note that multiple instances of this class will exist as there will be one
+    per data store (and not one per physical database).
     """
 
-    __slots__ = [
-        "txn",
-        "name",
-        "database_engine",
-        "after_callbacks",
-        "exception_callbacks",
-    ]
-
-    def __init__(
-        self, txn, name, database_engine, after_callbacks=None, exception_callbacks=None
-    ):
-        object.__setattr__(self, "txn", txn)
-        object.__setattr__(self, "name", name)
-        object.__setattr__(self, "database_engine", database_engine)
-        object.__setattr__(self, "after_callbacks", after_callbacks)
-        object.__setattr__(self, "exception_callbacks", exception_callbacks)
-
-    def call_after(self, callback, *args, **kwargs):
-        """Call the given callback on the main twisted thread after the
-        transaction has finished. Used to invalidate the caches on the
-        correct thread.
-        """
-        self.after_callbacks.append((callback, args, kwargs))
-
-    def call_on_exception(self, callback, *args, **kwargs):
-        self.exception_callbacks.append((callback, args, kwargs))
-
-    def __getattr__(self, name):
-        return getattr(self.txn, name)
-
-    def __setattr__(self, name, value):
-        setattr(self.txn, name, value)
-
-    def __iter__(self):
-        return self.txn.__iter__()
-
-    def execute_batch(self, sql, args):
-        if isinstance(self.database_engine, PostgresEngine):
-            from psycopg2.extras import execute_batch
-
-            self._do_execute(lambda *x: execute_batch(self.txn, *x), sql, args)
-        else:
-            for val in args:
-                self.execute(sql, val)
-
-    def execute(self, sql, *args):
-        self._do_execute(self.txn.execute, sql, *args)
-
-    def executemany(self, sql, *args):
-        self._do_execute(self.txn.executemany, sql, *args)
-
-    def _make_sql_one_line(self, sql):
-        "Strip newlines out of SQL so that the loggers in the DB are on one line"
-        return " ".join(l.strip() for l in sql.splitlines() if l.strip())
-
-    def _do_execute(self, func, sql, *args):
-        sql = self._make_sql_one_line(sql)
-
-        # TODO(paul): Maybe use 'info' and 'debug' for values?
-        sql_logger.debug("[SQL] {%s} %s", self.name, sql)
-
-        sql = self.database_engine.convert_param_style(sql)
-        if args:
-            try:
-                sql_logger.debug("[SQL values] {%s} %r", self.name, args[0])
-            except Exception:
-                # Don't let logging failures stop SQL from working
-                pass
-
-        start = time.time()
-
-        try:
-            return func(sql, *args)
-        except Exception as e:
-            logger.debug("[SQL FAIL] {%s} %s", self.name, e)
-            raise
-        finally:
-            secs = time.time() - start
-            sql_logger.debug("[SQL time] {%s} %f sec", self.name, secs)
-            sql_query_timer.labels(sql.split()[0]).observe(secs)
-
-
-class PerformanceCounters(object):
-    def __init__(self):
-        self.current_counters = {}
-        self.previous_counters = {}
-
-    def update(self, key, duration_secs):
-        count, cum_time = self.current_counters.get(key, (0, 0))
-        count += 1
-        cum_time += duration_secs
-        self.current_counters[key] = (count, cum_time)
-
-    def interval(self, interval_duration_secs, limit=3):
-        counters = []
-        for name, (count, cum_time) in iteritems(self.current_counters):
-            prev_count, prev_time = self.previous_counters.get(name, (0, 0))
-            counters.append(
-                (
-                    (cum_time - prev_time) / interval_duration_secs,
-                    count - prev_count,
-                    name,
-                )
-            )
-
-        self.previous_counters = dict(self.current_counters)
-
-        counters.sort(reverse=True)
-
-        top_n_counters = ", ".join(
-            "%s(%d): %.3f%%" % (name, count, 100 * ratio)
-            for ratio, count, name in counters[:limit]
-        )
-
-        return top_n_counters
-
-
-class SQLBaseStore(object):
-    _TXN_ID = 0
-
-    def __init__(self, db_conn, hs):
+    def __init__(self, database: Database, db_conn, hs):
         self.hs = hs
         self._clock = hs.get_clock()
-        self._db_pool = hs.get_db_pool()
-
-        self._previous_txn_total_time = 0
-        self._current_txn_total_time = 0
-        self._previous_loop_ts = 0
-
-        # TODO(paul): These can eventually be removed once the metrics code
-        #   is running in mainline, and we have some nice monitoring frontends
-        #   to watch it
-        self._txn_perf_counters = PerformanceCounters()
-
-        self._get_event_cache = Cache(
-            "*getEvent*", keylen=3, max_entries=hs.config.event_cache_size
-        )
-
-        self._event_fetch_lock = threading.Condition()
-        self._event_fetch_list = []
-        self._event_fetch_ongoing = 0
-
-        self._pending_ds = []
-
-        self.database_engine = hs.database_engine
-
-        # A set of tables that are not safe to use native upserts in.
-        self._unsafe_to_upsert_tables = set(UNIQUE_INDEX_BACKGROUND_UPDATES.keys())
-
-        self._account_validity = self.hs.config.account_validity
-
-        # We add the user_directory_search table to the blacklist on SQLite
-        # because the existing search table does not have an index, making it
-        # unsafe to use native upserts.
-        if isinstance(self.database_engine, Sqlite3Engine):
-            self._unsafe_to_upsert_tables.add("user_directory_search")
-
-        if self.database_engine.can_native_upsert:
-            # Check ASAP (and then later, every 1s) to see if we have finished
-            # background updates of tables that aren't safe to update.
-            self._clock.call_later(
-                0.0,
-                run_as_background_process,
-                "upsert_safety_check",
-                self._check_safe_to_upsert,
-            )
-
+        self.database_engine = database.engine
+        self.db = database
         self.rand = random.SystemRandom()
 
-        if self._account_validity.enabled:
-            self._clock.call_later(
-                0.0,
-                run_as_background_process,
-                "account_validity_set_expiration_dates",
-                self._set_expiration_date_when_missing,
-            )
-
-    @defer.inlineCallbacks
-    def _check_safe_to_upsert(self):
-        """
-        Is it safe to use native UPSERT?
-
-        If there are background updates, we will need to wait, as they may be
-        the addition of indexes that set the UNIQUE constraint that we require.
-
-        If the background updates have not completed, wait 15 sec and check again.
-        """
-        updates = yield self._simple_select_list(
-            "background_updates",
-            keyvalues=None,
-            retcols=["update_name"],
-            desc="check_background_updates",
-        )
-        updates = [x["update_name"] for x in updates]
-
-        for table, update_name in UNIQUE_INDEX_BACKGROUND_UPDATES.items():
-            if update_name not in updates:
-                logger.debug("Now safe to upsert in %s", table)
-                self._unsafe_to_upsert_tables.discard(table)
-
-        # If there's any updates still running, reschedule to run.
-        if updates:
-            self._clock.call_later(
-                15.0,
-                run_as_background_process,
-                "upsert_safety_check",
-                self._check_safe_to_upsert,
-            )
-
-    @defer.inlineCallbacks
-    def _set_expiration_date_when_missing(self):
-        """
-        Retrieves the list of registered users that don't have an expiration date, and
-        adds an expiration date for each of them.
-        """
-
-        def select_users_with_no_expiration_date_txn(txn):
-            """Retrieves the list of registered users with no expiration date from the
-            database, filtering out deactivated users.
-            """
-            sql = (
-                "SELECT users.name FROM users"
-                " LEFT JOIN account_validity ON (users.name = account_validity.user_id)"
-                " WHERE account_validity.user_id is NULL AND users.deactivated = 0;"
-            )
-            txn.execute(sql, [])
-
-            res = self.cursor_to_dict(txn)
-            if res:
-                for user in res:
-                    self.set_expiration_date_for_user_txn(
-                        txn, user["name"], use_delta=True
-                    )
-
-        yield self.runInteraction(
-            "get_users_with_no_expiration_date",
-            select_users_with_no_expiration_date_txn,
-        )
-
-    def set_expiration_date_for_user_txn(self, txn, user_id, use_delta=False):
-        """Sets an expiration date to the account with the given user ID.
-
-        Args:
-             user_id (str): User ID to set an expiration date for.
-             use_delta (bool): If set to False, the expiration date for the user will be
-                now + validity period. If set to True, this expiration date will be a
-                random value in the [now + period - d ; now + period] range, d being a
-                delta equal to 10% of the validity period.
-        """
-        now_ms = self._clock.time_msec()
-        expiration_ts = now_ms + self._account_validity.period
-
-        if use_delta:
-            expiration_ts = self.rand.randrange(
-                expiration_ts - self._account_validity.startup_job_max_delta,
-                expiration_ts,
-            )
-
-        self._simple_insert_txn(
-            txn,
-            "account_validity",
-            values={
-                "user_id": user_id,
-                "expiration_ts_ms": expiration_ts,
-                "email_sent": False,
-            },
-        )
-
-    def start_profiling(self):
-        self._previous_loop_ts = monotonic_time()
-
-        def loop():
-            curr = self._current_txn_total_time
-            prev = self._previous_txn_total_time
-            self._previous_txn_total_time = curr
-
-            time_now = monotonic_time()
-            time_then = self._previous_loop_ts
-            self._previous_loop_ts = time_now
-
-            duration = time_now - time_then
-            ratio = (curr - prev) / duration
-
-            top_three_counters = self._txn_perf_counters.interval(duration, limit=3)
-
-            perf_logger.info(
-                "Total database time: %.3f%% {%s}", ratio * 100, top_three_counters
-            )
-
-        self._clock.looping_call(loop, 10000)
-
-    def _new_transaction(
-        self, conn, desc, after_callbacks, exception_callbacks, func, *args, **kwargs
-    ):
-        start = monotonic_time()
-        txn_id = self._TXN_ID
-
-        # We don't really need these to be unique, so let's stop it from
-        # growing really large.
-        self._TXN_ID = (self._TXN_ID + 1) % (MAX_TXN_ID)
-
-        name = "%s-%x" % (desc, txn_id)
-
-        transaction_logger.debug("[TXN START] {%s}", name)
-
-        try:
-            i = 0
-            N = 5
-            while True:
-                try:
-                    txn = conn.cursor()
-                    txn = LoggingTransaction(
-                        txn,
-                        name,
-                        self.database_engine,
-                        after_callbacks,
-                        exception_callbacks,
-                    )
-                    r = func(txn, *args, **kwargs)
-                    conn.commit()
-                    return r
-                except self.database_engine.module.OperationalError as e:
-                    # This can happen if the database disappears mid
-                    # transaction.
-                    logger.warning(
-                        "[TXN OPERROR] {%s} %s %d/%d",
-                        name,
-                        exception_to_unicode(e),
-                        i,
-                        N,
-                    )
-                    if i < N:
-                        i += 1
-                        try:
-                            conn.rollback()
-                        except self.database_engine.module.Error as e1:
-                            logger.warning(
-                                "[TXN EROLL] {%s} %s", name, exception_to_unicode(e1)
-                            )
-                        continue
-                    raise
-                except self.database_engine.module.DatabaseError as e:
-                    if self.database_engine.is_deadlock(e):
-                        logger.warning("[TXN DEADLOCK] {%s} %d/%d", name, i, N)
-                        if i < N:
-                            i += 1
-                            try:
-                                conn.rollback()
-                            except self.database_engine.module.Error as e1:
-                                logger.warning(
-                                    "[TXN EROLL] {%s} %s",
-                                    name,
-                                    exception_to_unicode(e1),
-                                )
-                            continue
-                    raise
-        except Exception as e:
-            logger.debug("[TXN FAIL] {%s} %s", name, e)
-            raise
-        finally:
-            end = monotonic_time()
-            duration = end - start
-
-            LoggingContext.current_context().add_database_transaction(duration)
-
-            transaction_logger.debug("[TXN END] {%s} %f sec", name, duration)
-
-            self._current_txn_total_time += duration
-            self._txn_perf_counters.update(desc, duration)
-            sql_txn_timer.labels(desc).observe(duration)
-
-    @defer.inlineCallbacks
-    def runInteraction(self, desc, func, *args, **kwargs):
-        """Starts a transaction on the database and runs a given function
-
-        Arguments:
-            desc (str): description of the transaction, for logging and metrics
-            func (func): callback function, which will be called with a
-                database transaction (twisted.enterprise.adbapi.Transaction) as
-                its first argument, followed by `args` and `kwargs`.
-
-            args (list): positional args to pass to `func`
-            kwargs (dict): named args to pass to `func`
-
-        Returns:
-            Deferred: The result of func
-        """
-        after_callbacks = []
-        exception_callbacks = []
-
-        if LoggingContext.current_context() == LoggingContext.sentinel:
-            logger.warn("Starting db txn '%s' from sentinel context", desc)
-
-        try:
-            result = yield self.runWithConnection(
-                self._new_transaction,
-                desc,
-                after_callbacks,
-                exception_callbacks,
-                func,
-                *args,
-                **kwargs
-            )
-
-            for after_callback, after_args, after_kwargs in after_callbacks:
-                after_callback(*after_args, **after_kwargs)
-        except:  # noqa: E722, as we reraise the exception this is fine.
-            for after_callback, after_args, after_kwargs in exception_callbacks:
-                after_callback(*after_args, **after_kwargs)
-            raise
-
-        return result
-
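As a hedged sketch of the calling pattern described in runInteraction's docstring above (the query and description are hypothetical, and the call is assumed to run inside an inlineCallbacks method, as elsewhere in this file):

    def _count_users_txn(txn):
        # txn is a LoggingTransaction wrapping the DB-API cursor.
        txn.execute("SELECT COUNT(*) FROM users")
        return txn.fetchone()[0]

    count = yield self.runInteraction("count_users", _count_users_txn)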
-    @defer.inlineCallbacks
-    def runWithConnection(self, func, *args, **kwargs):
-        """Wraps the .runWithConnection() method on the underlying db_pool.
-
-        Arguments:
-            func (func): callback function, which will be called with a
-                database connection (twisted.enterprise.adbapi.Connection) as
-                its first argument, followed by `args` and `kwargs`.
-            args (list): positional args to pass to `func`
-            kwargs (dict): named args to pass to `func`
-
-        Returns:
-            Deferred: The result of func
-        """
-        parent_context = LoggingContext.current_context()
-        if parent_context == LoggingContext.sentinel:
-            logger.warn(
-                "Starting db connection from sentinel context: metrics will be lost"
-            )
-            parent_context = None
-
-        start_time = monotonic_time()
-
-        def inner_func(conn, *args, **kwargs):
-            with LoggingContext("runWithConnection", parent_context) as context:
-                sched_duration_sec = monotonic_time() - start_time
-                sql_scheduling_timer.observe(sched_duration_sec)
-                context.add_database_scheduled(sched_duration_sec)
-
-                if self.database_engine.is_connection_closed(conn):
-                    logger.debug("Reconnecting closed database connection")
-                    conn.reconnect()
-
-                return func(conn, *args, **kwargs)
-
-        with PreserveLoggingContext():
-            result = yield self._db_pool.runWithConnection(inner_func, *args, **kwargs)
-
-        return result
-
-    @staticmethod
-    def cursor_to_dict(cursor):
-        """Converts a SQL cursor into an list of dicts.
-
-        Args:
-            cursor : The DBAPI cursor which has executed a query.
-        Returns:
-            A list of dicts where the key is the column header.
-        """
-        col_headers = list(intern(str(column[0])) for column in cursor.description)
-        results = list(dict(zip(col_headers, row)) for row in cursor)
-        return results
-
-    def _execute(self, desc, decoder, query, *args):
-        """Runs a single query for a result set.
-
-        Args:
-            decoder - The function which can resolve the cursor results to
-                something meaningful.
-            query - The query string to execute
-            *args - Query args.
-        Returns:
-            The result of decoder(results)
-        """
-
-        def interaction(txn):
-            txn.execute(query, args)
-            if decoder:
-                return decoder(txn)
-            else:
-                return txn.fetchall()
-
-        return self.runInteraction(desc, interaction)
-
-    # "Simple" SQL API methods that operate on a single table with no JOINs,
-    # no complex WHERE clauses, just a dict of values for columns.
-
-    @defer.inlineCallbacks
-    def _simple_insert(self, table, values, or_ignore=False, desc="_simple_insert"):
-        """Executes an INSERT query on the named table.
-
-        Args:
-            table : string giving the table name
-            values : dict of new column names and values for them
-            or_ignore : bool. If False (the default), an IntegrityError is
-                raised when a conflicting row already exists; if True, the
-                function returns False instead
-            desc : string giving a description of the transaction
-
-        Returns:
-            bool: Whether the row was inserted or not. Only useful when
-            `or_ignore` is True
-        """
-        try:
-            yield self.runInteraction(desc, self._simple_insert_txn, table, values)
-        except self.database_engine.module.IntegrityError:
-            # We have to handle the or_ignore flag at this layer, since we
-            # can't reuse a cursor after we receive an error from the db.
-            if not or_ignore:
-                raise
-            return False
-        return True
-
-    @staticmethod
-    def _simple_insert_txn(txn, table, values):
-        keys, vals = zip(*values.items())
-
-        sql = "INSERT INTO %s (%s) VALUES(%s)" % (
-            table,
-            ", ".join(k for k in keys),
-            ", ".join("?" for _ in keys),
-        )
-
-        txn.execute(sql, vals)
-
-    def _simple_insert_many(self, table, values, desc):
-        return self.runInteraction(desc, self._simple_insert_many_txn, table, values)
-
-    @staticmethod
-    def _simple_insert_many_txn(txn, table, values):
-        if not values:
-            return
-
-        # This is a *slight* abomination to get a list of tuples of key names
-        # and a list of tuples of value names.
-        #
-        # i.e. [{"a": 1, "b": 2}, {"c": 3, "d": 4}]
-        #         => [("a", "b",), ("c", "d",)] and [(1, 2,), (3, 4,)]
-        #
-        # The sort is to ensure that we don't rely on dictionary iteration
-        # order.
-        keys, vals = zip(
-            *[zip(*(sorted(i.items(), key=lambda kv: kv[0]))) for i in values if i]
-        )
-
-        for k in keys:
-            if k != keys[0]:
-                raise RuntimeError("All items must have the same keys")
-
-        sql = "INSERT INTO %s (%s) VALUES(%s)" % (
-            table,
-            ", ".join(k for k in keys[0]),
-            ", ".join("?" for _ in keys[0]),
-        )
-
-        txn.executemany(sql, vals)
-
-    @defer.inlineCallbacks
-    def _simple_upsert(
-        self,
-        table,
-        keyvalues,
-        values,
-        insertion_values={},
-        desc="_simple_upsert",
-        lock=True,
-    ):
-        """
-
-        `lock` should generally be set to True (the default), but can be set
-        to False if either of the following are true:
-
-        * there is a UNIQUE INDEX on the key columns. In this case a conflict
-          will cause an IntegrityError in which case this function will retry
-          the update.
-
-        * we somehow know that we are the only thread which will be updating
-          this table.
-
-        Args:
-            table (str): The table to upsert into
-            keyvalues (dict): The unique key columns and their new values
-            values (dict): The nonunique columns and their new values
-            insertion_values (dict): additional key/values to use only when
-                inserting
-            lock (bool): True to lock the table when doing the upsert.
-        Returns:
-            Deferred(None or bool): Native upserts always return None. Emulated
-            upserts return True if a new entry was created, False if an existing
-            one was updated.
-        """
-        attempts = 0
-        while True:
-            try:
-                result = yield self.runInteraction(
-                    desc,
-                    self._simple_upsert_txn,
-                    table,
-                    keyvalues,
-                    values,
-                    insertion_values,
-                    lock=lock,
-                )
-                return result
-            except self.database_engine.module.IntegrityError as e:
-                attempts += 1
-                if attempts >= 5:
-                    # don't retry forever, because things other than races
-                    # can cause IntegrityErrors
-                    raise
-
-                # presumably we raced with another transaction: let's retry.
-                logger.warn(
-                    "IntegrityError when upserting into %s; retrying: %s", table, e
-                )
-
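To make the calling convention of _simple_upsert concrete, a hedged sketch of a call (the table and column names are hypothetical, chosen only for illustration; the call is assumed to run inside an inlineCallbacks method):

    yield self._simple_upsert(
        table="example_table",                    # hypothetical table
        keyvalues={"user_id": user_id},           # unique key columns
        values={"last_seen": now_ms},             # updated if the row exists
        insertion_values={"created_ts": now_ms},  # only used when inserting
        desc="example_upsert",
        lock=False,  # only safe if a UNIQUE index covers the key columns
    )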
-    def _simple_upsert_txn(
-        self, txn, table, keyvalues, values, insertion_values={}, lock=True
-    ):
-        """
-        Pick the UPSERT method which works best on the platform. Either the
-        native one (Pg9.5+, recent SQLites), or fall back to an emulated method.
-
-        Args:
-            txn: The transaction to use.
-            table (str): The table to upsert into
-            keyvalues (dict): The unique key columns and their new values
-            values (dict): The nonunique columns and their new values
-            insertion_values (dict): additional key/values to use only when
-                inserting
-            lock (bool): True to lock the table when doing the upsert.
-        Returns:
-            None or bool: Native upserts always return None. Emulated
-            upserts return True if a new entry was created, False if an existing
-            one was updated.
-        """
-        if (
-            self.database_engine.can_native_upsert
-            and table not in self._unsafe_to_upsert_tables
-        ):
-            return self._simple_upsert_txn_native_upsert(
-                txn, table, keyvalues, values, insertion_values=insertion_values
-            )
-        else:
-            return self._simple_upsert_txn_emulated(
-                txn,
-                table,
-                keyvalues,
-                values,
-                insertion_values=insertion_values,
-                lock=lock,
-            )
-
-    def _simple_upsert_txn_emulated(
-        self, txn, table, keyvalues, values, insertion_values={}, lock=True
-    ):
-        """
-        Args:
-            table (str): The table to upsert into
-            keyvalues (dict): The unique key columns and their new values
-            values (dict): The nonunique columns and their new values
-            insertion_values (dict): additional key/values to use only when
-                inserting
-            lock (bool): True to lock the table when doing the upsert.
-        Returns:
-            bool: Return True if a new entry was created, False if an existing
-            one was updated.
-        """
-        # We need to lock the table :(, unless we're *really* careful
-        if lock:
-            self.database_engine.lock_table(txn, table)
-
-        def _getwhere(key):
-            # If the value we're passing in is None (aka NULL), we need to use
-            # IS, not =, as NULL = NULL equals NULL (False).
-            if keyvalues[key] is None:
-                return "%s IS ?" % (key,)
-            else:
-                return "%s = ?" % (key,)
-
-        if not values:
-            # If `values` is empty, then all of the values we care about are in
-            # the unique key, so there is nothing to UPDATE. We can just do a
-            # SELECT instead to see if it exists.
-            sql = "SELECT 1 FROM %s WHERE %s" % (
-                table,
-                " AND ".join(_getwhere(k) for k in keyvalues),
-            )
-            sqlargs = list(keyvalues.values())
-            txn.execute(sql, sqlargs)
-            if txn.fetchall():
-                # We have an existing record.
-                return False
-        else:
-            # First try to update.
-            sql = "UPDATE %s SET %s WHERE %s" % (
-                table,
-                ", ".join("%s = ?" % (k,) for k in values),
-                " AND ".join(_getwhere(k) for k in keyvalues),
-            )
-            sqlargs = list(values.values()) + list(keyvalues.values())
-
-            txn.execute(sql, sqlargs)
-            if txn.rowcount > 0:
-                # successfully updated at least one row.
-                return False
-
-        # We didn't find any existing rows, so insert a new one
-        allvalues = {}
-        allvalues.update(keyvalues)
-        allvalues.update(values)
-        allvalues.update(insertion_values)
-
-        sql = "INSERT INTO %s (%s) VALUES (%s)" % (
-            table,
-            ", ".join(k for k in allvalues),
-            ", ".join("?" for _ in allvalues),
-        )
-        txn.execute(sql, list(allvalues.values()))
-        # successfully inserted
-        return True
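A small, self-contained illustration of the NULL handling in _getwhere above (the key/value pairs are hypothetical):

    def example_where_clause(keyvalues):
        # Mirrors the _getwhere logic: NULL keys need "IS ?", because in SQL
        # "NULL = NULL" does not evaluate to true, whereas "NULL IS NULL" does.
        def getwhere(key):
            return "%s IS ?" % (key,) if keyvalues[key] is None else "%s = ?" % (key,)

        return " AND ".join(getwhere(k) for k in keyvalues)

    # Hypothetical input produces: "user_id = ? AND device_id IS ?"
    print(example_where_clause({"user_id": "@alice:example.com", "device_id": None}))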
-
-    def _simple_upsert_txn_native_upsert(
-        self, txn, table, keyvalues, values, insertion_values={}
-    ):
-        """
-        Use the native UPSERT functionality in recent PostgreSQL versions.
-
-        Args:
-            table (str): The table to upsert into
-            keyvalues (dict): The unique key columns and their new values
-            values (dict): The nonunique columns and their new values
-            insertion_values (dict): additional key/values to use only when
-                inserting
-        Returns:
-            None
-        """
-        allvalues = {}
-        allvalues.update(keyvalues)
-        allvalues.update(insertion_values)
-
-        if not values:
-            latter = "NOTHING"
-        else:
-            allvalues.update(values)
-            latter = "UPDATE SET " + ", ".join(k + "=EXCLUDED." + k for k in values)
-
-        sql = ("INSERT INTO %s (%s) VALUES (%s) " "ON CONFLICT (%s) DO %s") % (
-            table,
-            ", ".join(k for k in allvalues),
-            ", ".join("?" for _ in allvalues),
-            ", ".join(k for k in keyvalues),
-            latter,
-        )
-        txn.execute(sql, list(allvalues.values()))
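For a concrete picture of the statement the native path builds, a hedged example (the table and column names are hypothetical):

    # With table="example_table", keyvalues={"user_id": u, "device_id": d},
    # values={"last_seen": ts} and empty insertion_values, the code above
    # produces a statement equivalent to:
    sql = (
        "INSERT INTO example_table (user_id, device_id, last_seen) VALUES (?, ?, ?) "
        "ON CONFLICT (user_id, device_id) DO UPDATE SET last_seen=EXCLUDED.last_seen"
    )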
-
-    def _simple_upsert_many_txn(
-        self, txn, table, key_names, key_values, value_names, value_values
-    ):
-        """
-        Upsert, many times.
-
-        Args:
-            table (str): The table to upsert into
-            key_names (list[str]): The key column names.
-            key_values (list[list]): A list of each row's key column values.
-            value_names (list[str]): The value column names. If empty, no
-                values will be used, even if value_values is provided.
-            value_values (list[list]): A list of each row's value column values.
-        Returns:
-            None
-        """
-        if (
-            self.database_engine.can_native_upsert
-            and table not in self._unsafe_to_upsert_tables
-        ):
-            return self._simple_upsert_many_txn_native_upsert(
-                txn, table, key_names, key_values, value_names, value_values
-            )
-        else:
-            return self._simple_upsert_many_txn_emulated(
-                txn, table, key_names, key_values, value_names, value_values
-            )
-
-    def _simple_upsert_many_txn_emulated(
-        self, txn, table, key_names, key_values, value_names, value_values
-    ):
-        """
-        Upsert, many times, but without native UPSERT support or batching.
-
-        Args:
-            table (str): The table to upsert into
-            key_names (list[str]): The key column names.
-            key_values (list[list]): A list of each row's key column values.
-            value_names (list[str]): The value column names. If empty, no
-                values will be used, even if value_values is provided.
-            value_values (list[list]): A list of each row's value column values.
-        Returns:
-            None
-        """
-        # No value columns, therefore make a blank list so that the following
-        # zip() works correctly.
-        if not value_names:
-            value_values = [() for x in range(len(key_values))]
-
-        for keyv, valv in zip(key_values, value_values):
-            _keys = {x: y for x, y in zip(key_names, keyv)}
-            _vals = {x: y for x, y in zip(value_names, valv)}
-
-            self._simple_upsert_txn_emulated(txn, table, _keys, _vals)
-
-    def _simple_upsert_many_txn_native_upsert(
-        self, txn, table, key_names, key_values, value_names, value_values
-    ):
-        """
-        Upsert, many times, using batching where possible.
-
-        Args:
-            table (str): The table to upsert into
-            key_names (list[str]): The key column names.
-            key_values (list[list]): A list of each row's key column values.
-            value_names (list[str]): The value column names. If empty, no
-                values will be used, even if value_values is provided.
-            value_values (list[list]): A list of each row's value column values.
-        Returns:
-            None
-        """
-        allnames = []
-        allnames.extend(key_names)
-        allnames.extend(value_names)
-
-        if not value_names:
-            # No value columns, therefore make a blank list so that the
-            # following zip() works correctly.
-            latter = "NOTHING"
-            value_values = [() for x in range(len(key_values))]
-        else:
-            latter = "UPDATE SET " + ", ".join(
-                k + "=EXCLUDED." + k for k in value_names
-            )
-
-        sql = "INSERT INTO %s (%s) VALUES (%s) ON CONFLICT (%s) DO %s" % (
-            table,
-            ", ".join(k for k in allnames),
-            ", ".join("?" for _ in allnames),
-            ", ".join(key_names),
-            latter,
-        )
-
-        args = []
-
-        for x, y in zip(key_values, value_values):
-            args.append(tuple(x) + tuple(y))
-
-        return txn.execute_batch(sql, args)
-
-    def _simple_select_one(
-        self, table, keyvalues, retcols, allow_none=False, desc="_simple_select_one"
-    ):
-        """Executes a SELECT query on the named table, which is expected to
-        return a single row, returning multiple columns from it.
-
-        Args:
-            table : string giving the table name
-            keyvalues : dict of column names and values to select the row with
-            retcols : list of strings giving the names of the columns to return
-
-            allow_none : If true, return None instead of failing if the SELECT
-              statement returns no rows
-        """
-        return self.runInteraction(
-            desc, self._simple_select_one_txn, table, keyvalues, retcols, allow_none
-        )
-
-    def _simple_select_one_onecol(
-        self,
-        table,
-        keyvalues,
-        retcol,
-        allow_none=False,
-        desc="_simple_select_one_onecol",
-    ):
-        """Executes a SELECT query on the named table, which is expected to
-        return a single row, returning a single column from it.
-
-        Args:
-            table : string giving the table name
-            keyvalues : dict of column names and values to select the row with
-            retcol : string giving the name of the column to return
-        """
-        return self.runInteraction(
-            desc,
-            self._simple_select_one_onecol_txn,
-            table,
-            keyvalues,
-            retcol,
-            allow_none=allow_none,
-        )
-
-    @classmethod
-    def _simple_select_one_onecol_txn(
-        cls, txn, table, keyvalues, retcol, allow_none=False
-    ):
-        ret = cls._simple_select_onecol_txn(
-            txn, table=table, keyvalues=keyvalues, retcol=retcol
-        )
-
-        if ret:
-            return ret[0]
-        else:
-            if allow_none:
-                return None
-            else:
-                raise StoreError(404, "No row found")
-
-    @staticmethod
-    def _simple_select_onecol_txn(txn, table, keyvalues, retcol):
-        sql = ("SELECT %(retcol)s FROM %(table)s") % {"retcol": retcol, "table": table}
-
-        if keyvalues:
-            sql += " WHERE %s" % " AND ".join("%s = ?" % k for k in iterkeys(keyvalues))
-            txn.execute(sql, list(keyvalues.values()))
-        else:
-            txn.execute(sql)
-
-        return [r[0] for r in txn]
-
-    def _simple_select_onecol(
-        self, table, keyvalues, retcol, desc="_simple_select_onecol"
-    ):
-        """Executes a SELECT query on the named table, which returns a list
-        comprising the values of the named column from the selected rows.
-
-        Args:
-            table (str): table name
-            keyvalues (dict|None): column names and values to select the rows with
-            retcol (str): column whose value we wish to retrieve.
-
-        Returns:
-            Deferred: Results in a list
-        """
-        return self.runInteraction(
-            desc, self._simple_select_onecol_txn, table, keyvalues, retcol
-        )
-
-    def _simple_select_list(
-        self, table, keyvalues, retcols, desc="_simple_select_list"
-    ):
-        """Executes a SELECT query on the named table, which may return zero or
-        more rows, returning the result as a list of dicts.
-
-        Args:
-            table (str): the table name
-            keyvalues (dict[str, Any] | None):
-                column names and values to select the rows with, or None to not
-                apply a WHERE clause.
-            retcols (iterable[str]): the names of the columns to return
-        Returns:
-            defer.Deferred: resolves to list[dict[str, Any]]
-        """
-        return self.runInteraction(
-            desc, self._simple_select_list_txn, table, keyvalues, retcols
-        )
-
-    @classmethod
-    def _simple_select_list_txn(cls, txn, table, keyvalues, retcols):
-        """Executes a SELECT query on the named table, which may return zero or
-        more rows, returning the result as a list of dicts.
-
-        Args:
-            txn : Transaction object
-            table (str): the table name
-            keyvalues (dict[str, T] | None):
-                column names and values to select the rows with, or None to not
-                apply a WHERE clause.
-            retcols (iterable[str]): the names of the columns to return
-        """
-        if keyvalues:
-            sql = "SELECT %s FROM %s WHERE %s" % (
-                ", ".join(retcols),
-                table,
-                " AND ".join("%s = ?" % (k,) for k in keyvalues),
-            )
-            txn.execute(sql, list(keyvalues.values()))
-        else:
-            sql = "SELECT %s FROM %s" % (", ".join(retcols), table)
-            txn.execute(sql)
-
-        return cls.cursor_to_dict(txn)
-
-    @defer.inlineCallbacks
-    def _simple_select_many_batch(
-        self,
-        table,
-        column,
-        iterable,
-        retcols,
-        keyvalues={},
-        desc="_simple_select_many_batch",
-        batch_size=100,
-    ):
-        """Executes a SELECT query on the named table, which may return zero or
-        more rows, returning the result as a list of dicts.
-
-        Filters rows by whether the value of `column` is in `iterable`.
-
-        Args:
-            table : string giving the table name
-            column : column name to test for inclusion against `iterable`
-            iterable : list
-            keyvalues : dict of column names and values to select the rows with
-            retcols : list of strings giving the names of the columns to return
-        """
-        results = []
-
-        if not iterable:
-            return results
-
-        # iterables cannot be sliced, so convert the iterable to a list first
-        it_list = list(iterable)
-
-        chunks = [
-            it_list[i : i + batch_size] for i in range(0, len(it_list), batch_size)
-        ]
-        for chunk in chunks:
-            rows = yield self.runInteraction(
-                desc,
-                self._simple_select_many_txn,
-                table,
-                column,
-                chunk,
-                keyvalues,
-                retcols,
-            )
-
-            results.extend(rows)
-
-        return results
-
-    @classmethod
-    def _simple_select_many_txn(cls, txn, table, column, iterable, keyvalues, retcols):
-        """Executes a SELECT query on the named table, which may return zero or
-        more rows, returning the result as a list of dicts.
-
-        Filters rows by whether the value of `column` is in `iterable`.
-
-        Args:
-            txn : Transaction object
-            table : string giving the table name
-            column : column name to test for inclusion against `iterable`
-            iterable : list
-            keyvalues : dict of column names and values to select the rows with
-            retcols : list of strings giving the names of the columns to return
-        """
-        if not iterable:
-            return []
-
-        sql = "SELECT %s FROM %s" % (", ".join(retcols), table)
-
-        clauses = []
-        values = []
-        clauses.append("%s IN (%s)" % (column, ",".join("?" for _ in iterable)))
-        values.extend(iterable)
-
-        for key, value in iteritems(keyvalues):
-            clauses.append("%s = ?" % (key,))
-            values.append(value)
-
-        if clauses:
-            sql = "%s WHERE %s" % (sql, " AND ".join(clauses))
-
-        txn.execute(sql, values)
-        return cls.cursor_to_dict(txn)
-
-    def _simple_update(self, table, keyvalues, updatevalues, desc):
-        return self.runInteraction(
-            desc, self._simple_update_txn, table, keyvalues, updatevalues
-        )
-
-    @staticmethod
-    def _simple_update_txn(txn, table, keyvalues, updatevalues):
-        if keyvalues:
-            where = "WHERE %s" % " AND ".join("%s = ?" % k for k in iterkeys(keyvalues))
-        else:
-            where = ""
-
-        update_sql = "UPDATE %s SET %s %s" % (
-            table,
-            ", ".join("%s = ?" % (k,) for k in updatevalues),
-            where,
-        )
-
-        txn.execute(update_sql, list(updatevalues.values()) + list(keyvalues.values()))
-
-        return txn.rowcount
-
-    def _simple_update_one(
-        self, table, keyvalues, updatevalues, desc="_simple_update_one"
-    ):
-        """Executes an UPDATE query on the named table, setting new values for
-        columns in a row matching the key values.
-
-        Args:
-            table : string giving the table name
-            keyvalues : dict of column names and values to select the row with
-            updatevalues : dict giving column names and values to update
-
-        The UPDATE is performed within a single transaction; compare-and-set
-        can be implemented by putting the column being updated in the
-        'keyvalues' dict as well.
-        """
-        return self.runInteraction(
-            desc, self._simple_update_one_txn, table, keyvalues, updatevalues
-        )
-
-    @classmethod
-    def _simple_update_one_txn(cls, txn, table, keyvalues, updatevalues):
-        rowcount = cls._simple_update_txn(txn, table, keyvalues, updatevalues)
-
-        if rowcount == 0:
-            raise StoreError(404, "No row found (%s)" % (table,))
-        if rowcount > 1:
-            raise StoreError(500, "More than one row matched (%s)" % (table,))
-
-    @staticmethod
-    def _simple_select_one_txn(txn, table, keyvalues, retcols, allow_none=False):
-        select_sql = "SELECT %s FROM %s WHERE %s" % (
-            ", ".join(retcols),
-            table,
-            " AND ".join("%s = ?" % (k,) for k in keyvalues),
-        )
-
-        txn.execute(select_sql, list(keyvalues.values()))
-        row = txn.fetchone()
-
-        if not row:
-            if allow_none:
-                return None
-            raise StoreError(404, "No row found (%s)" % (table,))
-        if txn.rowcount > 1:
-            raise StoreError(500, "More than one row matched (%s)" % (table,))
-
-        return dict(zip(retcols, row))
-
-    def _simple_delete_one(self, table, keyvalues, desc="_simple_delete_one"):
-        """Executes a DELETE query on the named table, expecting to delete a
-        single row.
-
-        Args:
-            table : string giving the table name
-            keyvalues : dict of column names and values to select the row with
-        """
-        return self.runInteraction(desc, self._simple_delete_one_txn, table, keyvalues)
-
-    @staticmethod
-    def _simple_delete_one_txn(txn, table, keyvalues):
-        """Executes a DELETE query on the named table, expecting to delete a
-        single row.
-
-        Args:
-            table : string giving the table name
-            keyvalues : dict of column names and values to select the row with
-        """
-        sql = "DELETE FROM %s WHERE %s" % (
-            table,
-            " AND ".join("%s = ?" % (k,) for k in keyvalues),
-        )
-
-        txn.execute(sql, list(keyvalues.values()))
-        if txn.rowcount == 0:
-            raise StoreError(404, "No row found (%s)" % (table,))
-        if txn.rowcount > 1:
-            raise StoreError(500, "More than one row matched (%s)" % (table,))
-
-    def _simple_delete(self, table, keyvalues, desc):
-        return self.runInteraction(desc, self._simple_delete_txn, table, keyvalues)
-
-    @staticmethod
-    def _simple_delete_txn(txn, table, keyvalues):
-        sql = "DELETE FROM %s WHERE %s" % (
-            table,
-            " AND ".join("%s = ?" % (k,) for k in keyvalues),
-        )
-
-        txn.execute(sql, list(keyvalues.values()))
-        return txn.rowcount
-
-    def _simple_delete_many(self, table, column, iterable, keyvalues, desc):
-        return self.runInteraction(
-            desc, self._simple_delete_many_txn, table, column, iterable, keyvalues
-        )
-
-    @staticmethod
-    def _simple_delete_many_txn(txn, table, column, iterable, keyvalues):
-        """Executes a DELETE query on the named table.
-
-        Filters rows by whether the value of `column` is in `iterable`.
-
-        Args:
-            txn : Transaction object
-            table : string giving the table name
-            column : column name to test for inclusion against `iterable`
-            iterable : list
-            keyvalues : dict of column names and values to select the rows with
-
-        Returns:
-            int: Number of rows deleted
-        """
-        if not iterable:
-            return 0
-
-        sql = "DELETE FROM %s" % table
-
-        clauses = []
-        values = []
-        clauses.append("%s IN (%s)" % (column, ",".join("?" for _ in iterable)))
-        values.extend(iterable)
-
-        for key, value in iteritems(keyvalues):
-            clauses.append("%s = ?" % (key,))
-            values.append(value)
-
-        if clauses:
-            sql = "%s WHERE %s" % (sql, " AND ".join(clauses))
-        txn.execute(sql, values)
-
-        return txn.rowcount
-
-    def _get_cache_dict(
-        self, db_conn, table, entity_column, stream_column, max_value, limit=100000
-    ):
-        # Fetch a mapping of room_id -> max stream position for "recent" rooms.
-        # It doesn't really matter how many we get; the StreamChangeCache will
-        # do the right thing to ensure it respects the max size of the cache.
-        sql = (
-            "SELECT %(entity)s, MAX(%(stream)s) FROM %(table)s"
-            " WHERE %(stream)s > ? - %(limit)s"
-            " GROUP BY %(entity)s"
-        ) % {
-            "table": table,
-            "entity": entity_column,
-            "stream": stream_column,
-            "limit": limit,
-        }
-
-        sql = self.database_engine.convert_param_style(sql)
-
-        txn = db_conn.cursor()
-        txn.execute(sql, (int(max_value),))
-
-        cache = {row[0]: int(row[1]) for row in txn}
-
-        txn.close()
-
-        if cache:
-            min_val = min(itervalues(cache))
-        else:
-            min_val = max_value
-
-        return cache, min_val
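A sketch of the shape of what _get_cache_dict returns, with hypothetical room IDs and stream positions:

    # A mapping of entity -> most recent stream position seen for it within
    # the last `limit` stream rows, plus the smallest such position (or
    # max_value if nothing matched):
    cache = {"!room1:example.com": 1020, "!room2:example.com": 998}
    min_val = 998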
-
-    def _invalidate_cache_and_stream(self, txn, cache_func, keys):
-        """Invalidates the cache and adds it to the cache stream so slaves
-        will know to invalidate their caches.
-
-        This should only be used to invalidate caches where slaves won't
-        otherwise know from other replication streams that the cache should
-        be invalidated.
-        """
-        txn.call_after(cache_func.invalidate, keys)
-        self._send_invalidation_to_replication(txn, cache_func.__name__, keys)
-
-    def _invalidate_state_caches_and_stream(self, txn, room_id, members_changed):
-        """Special case invalidation of caches based on current state.
-
-        We special case this so that we can batch the cache invalidations into a
-        single replication poke.
-
-        Args:
-            txn
-            room_id (str): Room where state changed
-            members_changed (iterable[str]): The user_ids of members that have changed
-        """
-        txn.call_after(self._invalidate_state_caches, room_id, members_changed)
-
-        if members_changed:
-            # We need to be careful that the size of the `members_changed` list
-            # isn't so large that it causes problems sending over replication, so we
-            # send them in chunks.
-            # Max line length is 16K, and max user ID length is 255, so 50 should
-            # be safe.
-            for chunk in batch_iter(members_changed, 50):
-                keys = itertools.chain([room_id], chunk)
-                self._send_invalidation_to_replication(
-                    txn, _CURRENT_STATE_CACHE_NAME, keys
-                )
-        else:
-            # if no members changed, we still need to invalidate the other caches.
-            self._send_invalidation_to_replication(
-                txn, _CURRENT_STATE_CACHE_NAME, [room_id]
-            )
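A quick sanity check of the chunk size chosen above, using the bounds quoted in the comment (roughly 16K max line length and 255 byte max user ID):

    # 50 user IDs of at most 255 bytes each stay comfortably under the limit.
    assert 50 * 255 < 16 * 1024  # 12750 < 16384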
+    def process_replication_rows(self, stream_name, instance_name, token, rows):
+        pass
 
     def _invalidate_state_caches(self, room_id, members_changed):
         """Invalidates caches that are based on the current state, but does
@@ -1421,7 +56,7 @@ class SQLBaseStore(object):
             members_changed (iterable[str]): The user_ids of members that have
                 changed
         """
-        for host in set(get_domain_from_id(u) for u in members_changed):
+        for host in {get_domain_from_id(u) for u in members_changed}:
             self._attempt_to_invalidate_cache("is_host_joined", (room_id, host))
             self._attempt_to_invalidate_cache("was_host_joined", (room_id, host))
 
@@ -1429,242 +64,29 @@ class SQLBaseStore(object):
         self._attempt_to_invalidate_cache("get_room_summary", (room_id,))
         self._attempt_to_invalidate_cache("get_current_state_ids", (room_id,))
 
-    def _attempt_to_invalidate_cache(self, cache_name, key):
+    def _attempt_to_invalidate_cache(
+        self, cache_name: str, key: Optional[Collection[Any]]
+    ):
         """Attempts to invalidate the cache of the given name, ignoring if the
         cache doesn't exist. Mainly used for invalidating caches on workers,
         where they may not have the cache.
 
         Args:
-            cache_name (str)
-            key (tuple)
+            cache_name
+            key: Entry to invalidate. If None then invalidates the entire
+                cache.
         """
+
         try:
-            getattr(self, cache_name).invalidate(key)
+            if key is None:
+                getattr(self, cache_name).invalidate_all()
+            else:
+                getattr(self, cache_name).invalidate(tuple(key))
         except AttributeError:
             # We probably haven't pulled in the cache in this worker,
             # which is fine.
             pass
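A brief illustration of the two calling modes of the updated signature (the cache name and key are taken from the calls above; `self` is assumed to be a store instance):

    # Invalidate a single entry; the key is converted to a tuple internally.
    self._attempt_to_invalidate_cache("get_room_summary", (room_id,))

    # Invalidate the entire cache on this worker.
    self._attempt_to_invalidate_cache("get_room_summary", None)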
 
-    def _send_invalidation_to_replication(self, txn, cache_name, keys):
-        """Notifies replication that given cache has been invalidated.
-
-        Note that this does *not* invalidate the cache locally.
-
-        Args:
-            txn
-            cache_name (str)
-            keys (iterable[str])
-        """
-
-        if isinstance(self.database_engine, PostgresEngine):
-            # get_next() returns a context manager which is designed to wrap
-            # the transaction. However, we want to only get an ID when we want
-            # to use it, here, so we need to call __enter__ manually, and have
-            # __exit__ called after the transaction finishes.
-            ctx = self._cache_id_gen.get_next()
-            stream_id = ctx.__enter__()
-            txn.call_on_exception(ctx.__exit__, None, None, None)
-            txn.call_after(ctx.__exit__, None, None, None)
-            txn.call_after(self.hs.get_notifier().on_new_replication_data)
-
-            self._simple_insert_txn(
-                txn,
-                table="cache_invalidation_stream",
-                values={
-                    "stream_id": stream_id,
-                    "cache_func": cache_name,
-                    "keys": list(keys),
-                    "invalidation_ts": self.clock.time_msec(),
-                },
-            )
-
-    def get_all_updated_caches(self, last_id, current_id, limit):
-        if last_id == current_id:
-            return defer.succeed([])
-
-        def get_all_updated_caches_txn(txn):
-            # We purposefully don't bound by the current token, as we want to
-            # send across cache invalidations as quickly as possible. Cache
-            # invalidations are idempotent, so duplicates are fine.
-            sql = (
-                "SELECT stream_id, cache_func, keys, invalidation_ts"
-                " FROM cache_invalidation_stream"
-                " WHERE stream_id > ? ORDER BY stream_id ASC LIMIT ?"
-            )
-            txn.execute(sql, (last_id, limit))
-            return txn.fetchall()
-
-        return self.runInteraction("get_all_updated_caches", get_all_updated_caches_txn)
-
-    def get_cache_stream_token(self):
-        if self._cache_id_gen:
-            return self._cache_id_gen.get_current_token()
-        else:
-            return 0
-
-    def _simple_select_list_paginate(
-        self,
-        table,
-        keyvalues,
-        orderby,
-        start,
-        limit,
-        retcols,
-        order_direction="ASC",
-        desc="_simple_select_list_paginate",
-    ):
-        """
-        Executes a SELECT query on the named table with start and limit,
-        of row numbers, which may return zero or number of rows from start to limit,
-        returning the result as a list of dicts.
-
-        Args:
-            table (str): the table name
-            keyvalues (dict[str, T] | None):
-                column names and values to select the rows with, or None to not
-                apply a WHERE clause.
-            orderby (str): Column to order the results by.
-            start (int): Index to begin the query at.
-            limit (int): Number of results to return.
-            retcols (iterable[str]): the names of the columns to return
-            order_direction (str): Whether the results should be ordered "ASC" or "DESC".
-        Returns:
-            defer.Deferred: resolves to list[dict[str, Any]]
-        """
-        return self.runInteraction(
-            desc,
-            self._simple_select_list_paginate_txn,
-            table,
-            keyvalues,
-            orderby,
-            start,
-            limit,
-            retcols,
-            order_direction=order_direction,
-        )
-
-    @classmethod
-    def _simple_select_list_paginate_txn(
-        cls,
-        txn,
-        table,
-        keyvalues,
-        orderby,
-        start,
-        limit,
-        retcols,
-        order_direction="ASC",
-    ):
-        """
-        Executes a SELECT query on the named table with start and limit,
-        of row numbers, which may return zero or number of rows from start to limit,
-        returning the result as a list of dicts.
-
-        Args:
-            txn : Transaction object
-            table (str): the table name
-            keyvalues (dict[str, T] | None):
-                column names and values to select the rows with, or None to not
-                apply a WHERE clause.
-            orderby (str): Column to order the results by.
-            start (int): Index to begin the query at.
-            limit (int): Number of results to return.
-            retcols (iterable[str]): the names of the columns to return
-            order_direction (str): Whether the results should be ordered "ASC" or "DESC".
-        Returns:
-            defer.Deferred: resolves to list[dict[str, Any]]
-        """
-        if order_direction not in ["ASC", "DESC"]:
-            raise ValueError("order_direction must be one of 'ASC' or 'DESC'.")
-
-        if keyvalues:
-            where_clause = "WHERE " + " AND ".join("%s = ?" % (k,) for k in keyvalues)
-        else:
-            where_clause = ""
-
-        sql = "SELECT %s FROM %s %s ORDER BY %s %s LIMIT ? OFFSET ?" % (
-            ", ".join(retcols),
-            table,
-            where_clause,
-            orderby,
-            order_direction,
-        )
-        txn.execute(sql, list(keyvalues.values()) + [limit, start])
-
-        return cls.cursor_to_dict(txn)
-
-    def get_user_count_txn(self, txn):
-        """Get a total number of registered users in the users list.
-
-        Args:
-            txn : Transaction object
-        Returns:
-            int : number of users
-        """
-        sql_count = "SELECT COUNT(*) FROM users WHERE is_guest = 0;"
-        txn.execute(sql_count)
-        return txn.fetchone()[0]
-
-    def _simple_search_list(
-        self, table, term, col, retcols, desc="_simple_search_list"
-    ):
-        """Executes a SELECT query on the named table, which may return zero or
-        more rows, returning the result as a list of dicts.
-
-        Args:
-            table (str): the table name
-            term (str | None):
-                term for searching the table matched to a column.
-            col (str): column to query term should be matched to
-            retcols (iterable[str]): the names of the columns to return
-        Returns:
-            defer.Deferred: resolves to list[dict[str, Any]] or None
-        """
-
-        return self.runInteraction(
-            desc, self._simple_search_list_txn, table, term, col, retcols
-        )
-
-    @classmethod
-    def _simple_search_list_txn(cls, txn, table, term, col, retcols):
-        """Executes a SELECT query on the named table, which may return zero or
-        more rows, returning the result as a list of dicts.
-
-        Args:
-            txn : Transaction object
-            table (str): the table name
-            term (str | None):
-                term for searching the table matched to a column.
-            col (str): column to query term should be matched to
-            retcols (iterable[str]): the names of the columns to return
-        Returns:
-            defer.Deferred: resolves to list[dict[str, Any]] or None
-        """
-        if term:
-            sql = "SELECT %s FROM %s WHERE %s LIKE ?" % (", ".join(retcols), table, col)
-            termvalues = ["%%" + term + "%%"]
-            txn.execute(sql, termvalues)
-        else:
-            return 0
-
-        return cls.cursor_to_dict(txn)
-
-    @property
-    def database_engine_name(self):
-        return self.database_engine.module.__name__
-
-    def get_server_version(self):
-        """Returns a string describing the server version number"""
-        return self.database_engine.server_version
-
-
-class _RollbackButIsFineException(Exception):
-    """ This exception is used to rollback a transaction without implying
-    something went wrong.
-    """
-
-    pass
-
 
 def db_to_json(db_content):
     """
@@ -1678,11 +100,6 @@ def db_to_json(db_content):
     if isinstance(db_content, memoryview):
         db_content = db_content.tobytes()
 
-    # psycopg2 on Python 2 returns buffer objects, which we need to cast to
-    # bytes to decode
-    if PY2 and isinstance(db_content, builtins.buffer):
-        db_content = bytes(db_content)
-
     # Decode it to a Unicode string before feeding it to json.loads, so we
     # consistently get a Unicode-containing object out.
     if isinstance(db_content, (bytes, bytearray)):
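
For orientation when reading the `_base.py` hunks above, here is a minimal, self-contained sketch of the contract the reworked `_attempt_to_invalidate_cache` follows. The `_FakeCache`/`_FakeStore` names are invented for illustration and are not Synapse code; the point is that a key of None now clears the whole cache via invalidate_all(), any other key is coerced to a tuple, and a cache that is missing on this worker is silently ignored.

    from typing import Any, Collection, Optional


    class _FakeCache:
        """Stand-in for a Synapse cache object (illustrative only)."""

        def __init__(self):
            self.entries = {("!room:example.org",): "state"}

        def invalidate(self, key):
            self.entries.pop(key, None)

        def invalidate_all(self):
            self.entries.clear()


    class _FakeStore:
        """Stand-in store exposing one cache attribute by name."""

        get_current_state_ids = _FakeCache()

        def _attempt_to_invalidate_cache(
            self, cache_name: str, key: Optional[Collection[Any]]
        ) -> None:
            try:
                if key is None:
                    getattr(self, cache_name).invalidate_all()
                else:
                    getattr(self, cache_name).invalidate(tuple(key))
            except AttributeError:
                # The cache is not present on this worker, which is fine.
                pass


    store = _FakeStore()
    store._attempt_to_invalidate_cache("get_current_state_ids", ["!room:example.org"])
    store._attempt_to_invalidate_cache("not_a_real_cache", None)  # silently ignored
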
diff --git a/synapse/storage/background_updates.py b/synapse/storage/background_updates.py
index e5f0668f09..59f3394b0a 100644
--- a/synapse/storage/background_updates.py
+++ b/synapse/storage/background_updates.py
@@ -14,6 +14,7 @@
 # limitations under the License.
 
 import logging
+from typing import Optional
 
 from canonicaljson import json
 
@@ -22,7 +23,6 @@ from twisted.internet import defer
 from synapse.metrics.background_process_metrics import run_as_background_process
 
 from . import engines
-from ._base import SQLBaseStore
 
 logger = logging.getLogger(__name__)
 
@@ -74,7 +74,7 @@ class BackgroundUpdatePerformance(object):
             return float(self.total_item_count) / float(self.total_duration_ms)
 
 
-class BackgroundUpdateStore(SQLBaseStore):
+class BackgroundUpdater(object):
     """ Background updates are updates to the database that run in the
     background. Each update processes a batch of data at once. We attempt to
     limit the impact of each update by monitoring how long each batch takes to
@@ -86,30 +86,34 @@ class BackgroundUpdateStore(SQLBaseStore):
     BACKGROUND_UPDATE_INTERVAL_MS = 1000
     BACKGROUND_UPDATE_DURATION_MS = 100
 
-    def __init__(self, db_conn, hs):
-        super(BackgroundUpdateStore, self).__init__(db_conn, hs)
+    def __init__(self, hs, database):
+        self._clock = hs.get_clock()
+        self.db = database
+
+        # if a background update is currently running, its name.
+        self._current_background_update = None  # type: Optional[str]
+
         self._background_update_performance = {}
-        self._background_update_queue = []
         self._background_update_handlers = {}
         self._all_done = False
 
     def start_doing_background_updates(self):
-        run_as_background_process("background_updates", self._run_background_updates)
+        run_as_background_process("background_updates", self.run_background_updates)
 
-    @defer.inlineCallbacks
-    def _run_background_updates(self):
+    async def run_background_updates(self, sleep=True):
         logger.info("Starting background schema updates")
         while True:
-            yield self.hs.get_clock().sleep(self.BACKGROUND_UPDATE_INTERVAL_MS / 1000.0)
+            if sleep:
+                await self._clock.sleep(self.BACKGROUND_UPDATE_INTERVAL_MS / 1000.0)
 
             try:
-                result = yield self.do_next_background_update(
+                result = await self.do_next_background_update(
                     self.BACKGROUND_UPDATE_DURATION_MS
                 )
             except Exception:
                 logger.exception("Error doing update")
             else:
-                if result is None:
+                if result:
                     logger.info(
                         "No more background updates to do."
                         " Unscheduling background update task."
@@ -117,30 +121,29 @@ class BackgroundUpdateStore(SQLBaseStore):
                     self._all_done = True
                     return None
 
-    @defer.inlineCallbacks
-    def has_completed_background_updates(self):
+    async def has_completed_background_updates(self) -> bool:
         """Check if all the background updates have completed
 
         Returns:
-            Deferred[bool]: True if all background updates have completed
+            True if all background updates have completed
         """
         # if we've previously determined that there is nothing left to do, that
         # is easy
         if self._all_done:
             return True
 
-        # obviously, if we have things in our queue, we're not done.
-        if self._background_update_queue:
+        # obviously, if we are currently processing an update, we're not done.
+        if self._current_background_update:
             return False
 
         # otherwise, check if there are updates to be run. This is important,
         # as we may be running on a worker which doesn't perform the bg updates
         # itself, but still wants to wait for them to happen.
-        updates = yield self._simple_select_onecol(
+        updates = await self.db.simple_select_onecol(
             "background_updates",
             keyvalues=None,
             retcol="1",
-            desc="check_background_updates",
+            desc="has_completed_background_updates",
         )
         if not updates:
             self._all_done = True
@@ -148,42 +151,79 @@ class BackgroundUpdateStore(SQLBaseStore):
 
         return False
 
-    @defer.inlineCallbacks
-    def do_next_background_update(self, desired_duration_ms):
+    async def has_completed_background_update(self, update_name) -> bool:
+        """Check if the given background update has finished running.
+        """
+        if self._all_done:
+            return True
+
+        if update_name == self._current_background_update:
+            return False
+
+        update_exists = await self.db.simple_select_one_onecol(
+            "background_updates",
+            keyvalues={"update_name": update_name},
+            retcol="1",
+            desc="has_completed_background_update",
+            allow_none=True,
+        )
+
+        return not update_exists
+
+    async def do_next_background_update(self, desired_duration_ms: float) -> bool:
         """Does some amount of work on the next queued background update
 
+        Returns once some amount of work is done.
+
         Args:
             desired_duration_ms(float): How long we want to spend
                 updating.
         Returns:
-            A deferred that completes once some amount of work is done.
-            The deferred will have a value of None if there is currently
-            no more work to do.
+            True if we have finished running all the background updates, otherwise False
         """
-        if not self._background_update_queue:
-            updates = yield self._simple_select_list(
-                "background_updates",
-                keyvalues=None,
-                retcols=("update_name", "depends_on"),
+
+        def get_background_updates_txn(txn):
+            txn.execute(
+                """
+                SELECT update_name, depends_on FROM background_updates
+                ORDER BY ordering, update_name
+                """
             )
-            in_flight = set(update["update_name"] for update in updates)
-            for update in updates:
-                if update["depends_on"] not in in_flight:
-                    self._background_update_queue.append(update["update_name"])
+            return self.db.cursor_to_dict(txn)
 
-        if not self._background_update_queue:
-            # no work left to do
-            return None
+        if not self._current_background_update:
+            all_pending_updates = await self.db.runInteraction(
+                "background_updates", get_background_updates_txn,
+            )
+            if not all_pending_updates:
+                # no work left to do
+                return True
+
+            # find the first update which isn't dependent on another one in the queue.
+            pending = {update["update_name"] for update in all_pending_updates}
+            for upd in all_pending_updates:
+                depends_on = upd["depends_on"]
+                if not depends_on or depends_on not in pending:
+                    break
+                logger.info(
+                    "Not starting on bg update %s until %s is done",
+                    upd["update_name"],
+                    depends_on,
+                )
+            else:
+                # if we get to the end of that for loop, there is a problem
+                raise Exception(
+                    "Unable to find a background update which doesn't depend on "
+                    "another: dependency cycle?"
+                )
 
-        # pop from the front, and add back to the back
-        update_name = self._background_update_queue.pop(0)
-        self._background_update_queue.append(update_name)
+            self._current_background_update = upd["update_name"]
 
-        res = yield self._do_background_update(update_name, desired_duration_ms)
-        return res
+        await self._do_background_update(desired_duration_ms)
+        return False
 
-    @defer.inlineCallbacks
-    def _do_background_update(self, update_name, desired_duration_ms):
+    async def _do_background_update(self, desired_duration_ms: float) -> int:
+        update_name = self._current_background_update
         logger.info("Starting update batch on background update '%s'", update_name)
 
         update_handler = self._background_update_handlers[update_name]
@@ -203,7 +243,7 @@ class BackgroundUpdateStore(SQLBaseStore):
         else:
             batch_size = self.DEFAULT_BACKGROUND_BATCH_SIZE
 
-        progress_json = yield self._simple_select_one_onecol(
+        progress_json = await self.db.simple_select_one_onecol(
             "background_updates",
             keyvalues={"update_name": update_name},
             retcol="progress_json",
@@ -212,13 +252,13 @@ class BackgroundUpdateStore(SQLBaseStore):
         progress = json.loads(progress_json)
 
         time_start = self._clock.time_msec()
-        items_updated = yield update_handler(progress, batch_size)
+        items_updated = await update_handler(progress, batch_size)
         time_stop = self._clock.time_msec()
 
         duration_ms = time_stop - time_start
 
         logger.info(
-            "Updating %r. Updated %r items in %rms."
+            "Running background update %r. Processed %r items in %rms."
             " (total_rate=%r/ms, current_rate=%r/ms, total_updated=%r, batch_size=%r)",
             update_name,
             items_updated,
@@ -241,7 +281,9 @@ class BackgroundUpdateStore(SQLBaseStore):
         * A dict of the current progress
         * An integer count of the number of items to update in this batch.
 
-        The handler should return a deferred integer count of items updated.
+        The handler should return a deferred or coroutine which returns an integer count
+        of items updated.
+
         The handler is responsible for updating the progress of the update.
 
         Args:
@@ -357,7 +399,7 @@ class BackgroundUpdateStore(SQLBaseStore):
             logger.debug("[SQL] %s", sql)
             c.execute(sql)
 
-        if isinstance(self.database_engine, engines.PostgresEngine):
+        if isinstance(self.db.engine, engines.PostgresEngine):
             runner = create_index_psql
         elif psql_only:
             runner = None
@@ -368,33 +410,12 @@ class BackgroundUpdateStore(SQLBaseStore):
         def updater(progress, batch_size):
             if runner is not None:
                 logger.info("Adding index %s to %s", index_name, table)
-                yield self.runWithConnection(runner)
+                yield self.db.runWithConnection(runner)
             yield self._end_background_update(update_name)
             return 1
 
         self.register_background_update_handler(update_name, updater)
 
-    def start_background_update(self, update_name, progress):
-        """Starts a background update running.
-
-        Args:
-            update_name: The update to set running.
-            progress: The initial state of the progress of the update.
-
-        Returns:
-            A deferred that completes once the task has been added to the
-            queue.
-        """
-        # Clear the background update queue so that we will pick up the new
-        # task on the next iteration of do_background_update.
-        self._background_update_queue = []
-        progress_json = json.dumps(progress)
-
-        return self._simple_insert(
-            "background_updates",
-            {"update_name": update_name, "progress_json": progress_json},
-        )
-
     def _end_background_update(self, update_name):
         """Removes a completed background update task from the queue.
 
@@ -403,13 +424,31 @@ class BackgroundUpdateStore(SQLBaseStore):
         Returns:
             A deferred that completes once the task is removed.
         """
-        self._background_update_queue = [
-            name for name in self._background_update_queue if name != update_name
-        ]
-        return self._simple_delete_one(
+        if update_name != self._current_background_update:
+            raise Exception(
+                "Cannot end background update %s which isn't currently running"
+                % update_name
+            )
+        self._current_background_update = None
+        return self.db.simple_delete_one(
             "background_updates", keyvalues={"update_name": update_name}
         )
 
+    def _background_update_progress(self, update_name: str, progress: dict):
+        """Update the progress of a background update
+
+        Args:
+            update_name: The name of the background update task
+            progress: The progress of the update.
+        """
+
+        return self.db.runInteraction(
+            "background_update_progress",
+            self._background_update_progress_txn,
+            update_name,
+            progress,
+        )
+
     def _background_update_progress_txn(self, txn, update_name, progress):
         """Update the progress of a background update
 
@@ -421,7 +460,7 @@ class BackgroundUpdateStore(SQLBaseStore):
 
         progress_json = json.dumps(progress)
 
-        self._simple_update_one_txn(
+        self.db.simple_update_one_txn(
             txn,
             "background_updates",
             keyvalues={"update_name": update_name},
diff --git a/synapse/storage/data_stores/__init__.py b/synapse/storage/data_stores/__init__.py
new file mode 100644
index 0000000000..599ee470d4
--- /dev/null
+++ b/synapse/storage/data_stores/__init__.py
@@ -0,0 +1,97 @@
+# -*- coding: utf-8 -*-
+# Copyright 2019 The Matrix.org Foundation C.I.C.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import logging
+
+from synapse.storage.data_stores.main.events import PersistEventsStore
+from synapse.storage.data_stores.state import StateGroupDataStore
+from synapse.storage.database import Database, make_conn
+from synapse.storage.engines import create_engine
+from synapse.storage.prepare_database import prepare_database
+
+logger = logging.getLogger(__name__)
+
+
+class DataStores(object):
+    """The various data stores.
+
+    These are low level interfaces to physical databases.
+
+    Attributes:
+        main (DataStore)
+    """
+
+    def __init__(self, main_store_class, hs):
+        # Note we pass in the main store class here as workers use a different main
+        # store.
+
+        self.databases = []
+        self.main = None
+        self.state = None
+        self.persist_events = None
+
+        for database_config in hs.config.database.databases:
+            db_name = database_config.name
+            engine = create_engine(database_config.config)
+
+            with make_conn(database_config, engine) as db_conn:
+                logger.info("Preparing database %r...", db_name)
+
+                engine.check_database(db_conn)
+                prepare_database(
+                    db_conn, engine, hs.config, data_stores=database_config.data_stores,
+                )
+
+                database = Database(hs, database_config, engine)
+
+                if "main" in database_config.data_stores:
+                    logger.info("Starting 'main' data store")
+
+                    # Sanity check we don't try and configure the main store on
+                    # multiple databases.
+                    if self.main:
+                        raise Exception("'main' data store already configured")
+
+                    self.main = main_store_class(database, db_conn, hs)
+
+                    # If we're on a process that can persist events also
+                    # instantiate a `PersistEventsStore`
+                    if hs.config.worker.writers.events == hs.get_instance_name():
+                        self.persist_events = PersistEventsStore(
+                            hs, database, self.main
+                        )
+
+                if "state" in database_config.data_stores:
+                    logger.info("Starting 'state' data store")
+
+                    # Sanity check we don't try and configure the state store on
+                    # multiple databases.
+                    if self.state:
+                        raise Exception("'state' data store already configured")
+
+                    self.state = StateGroupDataStore(database, db_conn, hs)
+
+                db_conn.commit()
+
+                self.databases.append(database)
+
+                logger.info("Database %r prepared", db_name)
+
+        # Sanity check that we have actually configured all the required stores.
+        if not self.main:
+            raise Exception("No 'main' data store configured")
+
+        if not self.state:
+            raise Exception("No 'main' data store configured")
diff --git a/synapse/storage/data_stores/main/__init__.py b/synapse/storage/data_stores/main/__init__.py
new file mode 100644
index 0000000000..4b4763c701
--- /dev/null
+++ b/synapse/storage/data_stores/main/__init__.py
@@ -0,0 +1,592 @@
+# -*- coding: utf-8 -*-
+# Copyright 2014-2016 OpenMarket Ltd
+# Copyright 2018 New Vector Ltd
+# Copyright 2019 The Matrix.org Foundation C.I.C.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import calendar
+import logging
+import time
+
+from synapse.api.constants import PresenceState
+from synapse.config.homeserver import HomeServerConfig
+from synapse.storage.database import Database
+from synapse.storage.engines import PostgresEngine
+from synapse.storage.util.id_generators import (
+    IdGenerator,
+    MultiWriterIdGenerator,
+    StreamIdGenerator,
+)
+from synapse.util.caches.stream_change_cache import StreamChangeCache
+
+from .account_data import AccountDataStore
+from .appservice import ApplicationServiceStore, ApplicationServiceTransactionStore
+from .cache import CacheInvalidationWorkerStore
+from .censor_events import CensorEventsStore
+from .client_ips import ClientIpStore
+from .deviceinbox import DeviceInboxStore
+from .devices import DeviceStore
+from .directory import DirectoryStore
+from .e2e_room_keys import EndToEndRoomKeyStore
+from .end_to_end_keys import EndToEndKeyStore
+from .event_federation import EventFederationStore
+from .event_push_actions import EventPushActionsStore
+from .events_bg_updates import EventsBackgroundUpdatesStore
+from .filtering import FilteringStore
+from .group_server import GroupServerStore
+from .keys import KeyStore
+from .media_repository import MediaRepositoryStore
+from .metrics import ServerMetricsStore
+from .monthly_active_users import MonthlyActiveUsersStore
+from .openid import OpenIdStore
+from .presence import PresenceStore, UserPresenceState
+from .profile import ProfileStore
+from .purge_events import PurgeEventsStore
+from .push_rule import PushRuleStore
+from .pusher import PusherStore
+from .receipts import ReceiptsStore
+from .registration import RegistrationStore
+from .rejections import RejectionsStore
+from .relations import RelationsStore
+from .room import RoomStore
+from .roommember import RoomMemberStore
+from .search import SearchStore
+from .signatures import SignatureStore
+from .state import StateStore
+from .stats import StatsStore
+from .stream import StreamStore
+from .tags import TagsStore
+from .transactions import TransactionStore
+from .ui_auth import UIAuthStore
+from .user_directory import UserDirectoryStore
+from .user_erasure_store import UserErasureStore
+
+logger = logging.getLogger(__name__)
+
+
+class DataStore(
+    EventsBackgroundUpdatesStore,
+    RoomMemberStore,
+    RoomStore,
+    RegistrationStore,
+    StreamStore,
+    ProfileStore,
+    PresenceStore,
+    TransactionStore,
+    DirectoryStore,
+    KeyStore,
+    StateStore,
+    SignatureStore,
+    ApplicationServiceStore,
+    PurgeEventsStore,
+    EventFederationStore,
+    MediaRepositoryStore,
+    RejectionsStore,
+    FilteringStore,
+    PusherStore,
+    PushRuleStore,
+    ApplicationServiceTransactionStore,
+    ReceiptsStore,
+    EndToEndKeyStore,
+    EndToEndRoomKeyStore,
+    SearchStore,
+    TagsStore,
+    AccountDataStore,
+    EventPushActionsStore,
+    OpenIdStore,
+    ClientIpStore,
+    DeviceStore,
+    DeviceInboxStore,
+    UserDirectoryStore,
+    GroupServerStore,
+    UserErasureStore,
+    MonthlyActiveUsersStore,
+    StatsStore,
+    RelationsStore,
+    CensorEventsStore,
+    UIAuthStore,
+    CacheInvalidationWorkerStore,
+    ServerMetricsStore,
+):
+    def __init__(self, database: Database, db_conn, hs):
+        self.hs = hs
+        self._clock = hs.get_clock()
+        self.database_engine = database.engine
+
+        self._presence_id_gen = StreamIdGenerator(
+            db_conn, "presence_stream", "stream_id"
+        )
+        self._device_inbox_id_gen = StreamIdGenerator(
+            db_conn, "device_max_stream_id", "stream_id"
+        )
+        self._public_room_id_gen = StreamIdGenerator(
+            db_conn, "public_room_list_stream", "stream_id"
+        )
+        self._device_list_id_gen = StreamIdGenerator(
+            db_conn,
+            "device_lists_stream",
+            "stream_id",
+            extra_tables=[
+                ("user_signature_stream", "stream_id"),
+                ("device_lists_outbound_pokes", "stream_id"),
+            ],
+        )
+        self._cross_signing_id_gen = StreamIdGenerator(
+            db_conn, "e2e_cross_signing_keys", "stream_id"
+        )
+
+        self._access_tokens_id_gen = IdGenerator(db_conn, "access_tokens", "id")
+        self._event_reports_id_gen = IdGenerator(db_conn, "event_reports", "id")
+        self._push_rule_id_gen = IdGenerator(db_conn, "push_rules", "id")
+        self._push_rules_enable_id_gen = IdGenerator(db_conn, "push_rules_enable", "id")
+        self._pushers_id_gen = StreamIdGenerator(
+            db_conn, "pushers", "id", extra_tables=[("deleted_pushers", "stream_id")]
+        )
+        self._group_updates_id_gen = StreamIdGenerator(
+            db_conn, "local_group_updates", "stream_id"
+        )
+
+        if isinstance(self.database_engine, PostgresEngine):
+            self._cache_id_gen = MultiWriterIdGenerator(
+                db_conn,
+                database,
+                instance_name="master",
+                table="cache_invalidation_stream_by_instance",
+                instance_column="instance_name",
+                id_column="stream_id",
+                sequence_name="cache_invalidation_stream_seq",
+            )
+        else:
+            self._cache_id_gen = None
+
+        super(DataStore, self).__init__(database, db_conn, hs)
+
+        self._presence_on_startup = self._get_active_presence(db_conn)
+
+        presence_cache_prefill, min_presence_val = self.db.get_cache_dict(
+            db_conn,
+            "presence_stream",
+            entity_column="user_id",
+            stream_column="stream_id",
+            max_value=self._presence_id_gen.get_current_token(),
+        )
+        self.presence_stream_cache = StreamChangeCache(
+            "PresenceStreamChangeCache",
+            min_presence_val,
+            prefilled_cache=presence_cache_prefill,
+        )
+
+        max_device_inbox_id = self._device_inbox_id_gen.get_current_token()
+        device_inbox_prefill, min_device_inbox_id = self.db.get_cache_dict(
+            db_conn,
+            "device_inbox",
+            entity_column="user_id",
+            stream_column="stream_id",
+            max_value=max_device_inbox_id,
+            limit=1000,
+        )
+        self._device_inbox_stream_cache = StreamChangeCache(
+            "DeviceInboxStreamChangeCache",
+            min_device_inbox_id,
+            prefilled_cache=device_inbox_prefill,
+        )
+        # The federation outbox and the local device inbox uses the same
+        # stream_id generator.
+        device_outbox_prefill, min_device_outbox_id = self.db.get_cache_dict(
+            db_conn,
+            "device_federation_outbox",
+            entity_column="destination",
+            stream_column="stream_id",
+            max_value=max_device_inbox_id,
+            limit=1000,
+        )
+        self._device_federation_outbox_stream_cache = StreamChangeCache(
+            "DeviceFederationOutboxStreamChangeCache",
+            min_device_outbox_id,
+            prefilled_cache=device_outbox_prefill,
+        )
+
+        device_list_max = self._device_list_id_gen.get_current_token()
+        self._device_list_stream_cache = StreamChangeCache(
+            "DeviceListStreamChangeCache", device_list_max
+        )
+        self._user_signature_stream_cache = StreamChangeCache(
+            "UserSignatureStreamChangeCache", device_list_max
+        )
+        self._device_list_federation_stream_cache = StreamChangeCache(
+            "DeviceListFederationStreamChangeCache", device_list_max
+        )
+
+        events_max = self._stream_id_gen.get_current_token()
+        curr_state_delta_prefill, min_curr_state_delta_id = self.db.get_cache_dict(
+            db_conn,
+            "current_state_delta_stream",
+            entity_column="room_id",
+            stream_column="stream_id",
+            max_value=events_max,  # As we share the stream id with events token
+            limit=1000,
+        )
+        self._curr_state_delta_stream_cache = StreamChangeCache(
+            "_curr_state_delta_stream_cache",
+            min_curr_state_delta_id,
+            prefilled_cache=curr_state_delta_prefill,
+        )
+
+        _group_updates_prefill, min_group_updates_id = self.db.get_cache_dict(
+            db_conn,
+            "local_group_updates",
+            entity_column="user_id",
+            stream_column="stream_id",
+            max_value=self._group_updates_id_gen.get_current_token(),
+            limit=1000,
+        )
+        self._group_updates_stream_cache = StreamChangeCache(
+            "_group_updates_stream_cache",
+            min_group_updates_id,
+            prefilled_cache=_group_updates_prefill,
+        )
+
+        self._stream_order_on_start = self.get_room_max_stream_ordering()
+        self._min_stream_order_on_start = self.get_room_min_stream_ordering()
+
+        # Used in _generate_user_daily_visits to keep track of progress
+        self._last_user_visit_update = self._get_start_of_day()
+
+    def take_presence_startup_info(self):
+        active_on_startup = self._presence_on_startup
+        self._presence_on_startup = None
+        return active_on_startup
+
+    def _get_active_presence(self, db_conn):
+        """Fetch non-offline presence from the database so that we can register
+        the appropriate timeouts.
+        """
+
+        sql = (
+            "SELECT user_id, state, last_active_ts, last_federation_update_ts,"
+            " last_user_sync_ts, status_msg, currently_active FROM presence_stream"
+            " WHERE state != ?"
+        )
+        sql = self.database_engine.convert_param_style(sql)
+
+        txn = db_conn.cursor()
+        txn.execute(sql, (PresenceState.OFFLINE,))
+        rows = self.db.cursor_to_dict(txn)
+        txn.close()
+
+        for row in rows:
+            row["currently_active"] = bool(row["currently_active"])
+
+        return [UserPresenceState(**row) for row in rows]
+
+    def count_daily_users(self):
+        """
+        Counts the number of users who used this homeserver in the last 24 hours.
+        """
+        yesterday = int(self._clock.time_msec()) - (1000 * 60 * 60 * 24)
+        return self.db.runInteraction("count_daily_users", self._count_users, yesterday)
+
+    def count_monthly_users(self):
+        """
+        Counts the number of users who used this homeserver in the last 30 days.
+        Note this method is intended for phonehome metrics only and is different
+        from the mau figure in synapse.storage.data_stores.main.monthly_active_users
+        which, amongst other things, includes a 3 day grace period before a user counts.
+        """
+        thirty_days_ago = int(self._clock.time_msec()) - (1000 * 60 * 60 * 24 * 30)
+        return self.db.runInteraction(
+            "count_monthly_users", self._count_users, thirty_days_ago
+        )
+
+    def _count_users(self, txn, time_from):
+        """
+        Returns number of users seen in the past time_from period
+        """
+        sql = """
+            SELECT COALESCE(count(*), 0) FROM (
+                SELECT user_id FROM user_ips
+                WHERE last_seen > ?
+                GROUP BY user_id
+            ) u
+        """
+        txn.execute(sql, (time_from,))
+        (count,) = txn.fetchone()
+        return count
+
+    def count_r30_users(self):
+        """
+        Counts the number of 30 day retained users, defined as:
+         * Users who have created their accounts more than 30 days ago
+         * Who were last seen at most 30 days ago
+         * Where account creation and last_seen are > 30 days apart
+
+         Returns counts globally, as well as broken down by platform.
+        """
+
+        def _count_r30_users(txn):
+            thirty_days_in_secs = 86400 * 30
+            now = int(self._clock.time())
+            thirty_days_ago_in_secs = now - thirty_days_in_secs
+
+            sql = """
+                SELECT platform, COALESCE(count(*), 0) FROM (
+                     SELECT
+                        users.name, platform, users.creation_ts * 1000,
+                        MAX(uip.last_seen)
+                     FROM users
+                     INNER JOIN (
+                         SELECT
+                         user_id,
+                         last_seen,
+                         CASE
+                             WHEN user_agent LIKE '%%Android%%' THEN 'android'
+                             WHEN user_agent LIKE '%%iOS%%' THEN 'ios'
+                             WHEN user_agent LIKE '%%Electron%%' THEN 'electron'
+                             WHEN user_agent LIKE '%%Mozilla%%' THEN 'web'
+                             WHEN user_agent LIKE '%%Gecko%%' THEN 'web'
+                             ELSE 'unknown'
+                         END
+                         AS platform
+                         FROM user_ips
+                     ) uip
+                     ON users.name = uip.user_id
+                     AND users.appservice_id is NULL
+                     AND users.creation_ts < ?
+                     AND uip.last_seen/1000 > ?
+                     AND (uip.last_seen/1000) - users.creation_ts > 86400 * 30
+                     GROUP BY users.name, platform, users.creation_ts
+                ) u GROUP BY platform
+            """
+
+            results = {}
+            txn.execute(sql, (thirty_days_ago_in_secs, thirty_days_ago_in_secs))
+
+            for row in txn:
+                if row[0] == "unknown":
+                    pass
+                results[row[0]] = row[1]
+
+            sql = """
+                SELECT COALESCE(count(*), 0) FROM (
+                    SELECT users.name, users.creation_ts * 1000,
+                                                        MAX(uip.last_seen)
+                    FROM users
+                    INNER JOIN (
+                        SELECT
+                        user_id,
+                        last_seen
+                        FROM user_ips
+                    ) uip
+                    ON users.name = uip.user_id
+                    AND appservice_id is NULL
+                    AND users.creation_ts < ?
+                    AND uip.last_seen/1000 > ?
+                    AND (uip.last_seen/1000) - users.creation_ts > 86400 * 30
+                    GROUP BY users.name, users.creation_ts
+                ) u
+            """
+
+            txn.execute(sql, (thirty_days_ago_in_secs, thirty_days_ago_in_secs))
+
+            (count,) = txn.fetchone()
+            results["all"] = count
+
+            return results
+
+        return self.db.runInteraction("count_r30_users", _count_r30_users)
+
+    def _get_start_of_day(self):
+        """
+        Returns millisecond unixtime for start of UTC day.
+        """
+        now = time.gmtime()
+        today_start = calendar.timegm((now.tm_year, now.tm_mon, now.tm_mday, 0, 0, 0))
+        return today_start * 1000
+
+    def generate_user_daily_visits(self):
+        """
+        Generates daily visit data for use in cohort/retention analysis.
+        """
+
+        def _generate_user_daily_visits(txn):
+            logger.info("Calling _generate_user_daily_visits")
+            today_start = self._get_start_of_day()
+            a_day_in_milliseconds = 24 * 60 * 60 * 1000
+            now = self.clock.time_msec()
+
+            sql = """
+                INSERT INTO user_daily_visits (user_id, device_id, timestamp)
+                    SELECT u.user_id, u.device_id, ?
+                    FROM user_ips AS u
+                    LEFT JOIN (
+                      SELECT user_id, device_id, timestamp FROM user_daily_visits
+                      WHERE timestamp = ?
+                    ) udv
+                    ON u.user_id = udv.user_id AND u.device_id=udv.device_id
+                    INNER JOIN users ON users.name=u.user_id
+                    WHERE last_seen > ? AND last_seen <= ?
+                    AND udv.timestamp IS NULL AND users.is_guest=0
+                    AND users.appservice_id IS NULL
+                    GROUP BY u.user_id, u.device_id
+            """
+
+            # This means that the day has rolled over but there could still
+            # be entries from the previous day. There is an edge case
+            # where if the user logs in at 23:59 and overwrites their
+            # last_seen at 00:01 then they will not be counted in the
+            # previous day's stats - it is important that the query is run
+            # often to minimise this case.
+            if today_start > self._last_user_visit_update:
+                yesterday_start = today_start - a_day_in_milliseconds
+                txn.execute(
+                    sql,
+                    (
+                        yesterday_start,
+                        yesterday_start,
+                        self._last_user_visit_update,
+                        today_start,
+                    ),
+                )
+                self._last_user_visit_update = today_start
+
+            txn.execute(
+                sql, (today_start, today_start, self._last_user_visit_update, now)
+            )
+            # Update _last_user_visit_update to now. The reason to do this
+            # rather than just clamping to the beginning of the day is to limit
+            # the size of the join - meaning that the query can be run more
+            # frequently
+            self._last_user_visit_update = now
+
+        return self.db.runInteraction(
+            "generate_user_daily_visits", _generate_user_daily_visits
+        )
+
+    def get_users(self):
+        """Function to retrieve a list of users in users table.
+
+        Args:
+        Returns:
+            defer.Deferred: resolves to list[dict[str, Any]]
+        """
+        return self.db.simple_select_list(
+            table="users",
+            keyvalues={},
+            retcols=[
+                "name",
+                "password_hash",
+                "is_guest",
+                "admin",
+                "user_type",
+                "deactivated",
+            ],
+            desc="get_users",
+        )
+
+    def get_users_paginate(
+        self, start, limit, name=None, guests=True, deactivated=False
+    ):
+        """Function to retrieve a paginated list of users from
+        users list. This will return a json list of users and the
+        total number of users matching the filter criteria.
+
+        Args:
+            start (int): start number to begin the query from
+            limit (int): number of rows to retrieve
+            name (string): filter for user names
+            guests (bool): whether to include guest users
+            deactivated (bool): whether to include deactivated users
+        Returns:
+            defer.Deferred: resolves to list[dict[str, Any]], int
+        """
+
+        def get_users_paginate_txn(txn):
+            filters = []
+            args = []
+
+            if name:
+                filters.append("name LIKE ?")
+                args.append("%" + name + "%")
+
+            if not guests:
+                filters.append("is_guest = 0")
+
+            if not deactivated:
+                filters.append("deactivated = 0")
+
+            where_clause = "WHERE " + " AND ".join(filters) if len(filters) > 0 else ""
+
+            sql = "SELECT COUNT(*) as total_users FROM users %s" % (where_clause)
+            txn.execute(sql, args)
+            count = txn.fetchone()[0]
+
+            args = [self.hs.config.server_name] + args + [limit, start]
+            sql = """
+                SELECT name, user_type, is_guest, admin, deactivated, displayname, avatar_url
+                FROM users as u
+                LEFT JOIN profiles AS p ON u.name = '@' || p.user_id || ':' || ?
+                {}
+                ORDER BY u.name LIMIT ? OFFSET ?
+                """.format(
+                where_clause
+            )
+            txn.execute(sql, args)
+            users = self.db.cursor_to_dict(txn)
+            return users, count
+
+        return self.db.runInteraction("get_users_paginate_txn", get_users_paginate_txn)
+
+    def search_users(self, term):
+        """Function to search users list for one or more users with
+        the matched term.
+
+        Args:
+            term (str): search term
+            col (str): column to query term should be matched to
+        Returns:
+            defer.Deferred: resolves to list[dict[str, Any]]
+        """
+        return self.db.simple_search_list(
+            table="users",
+            term=term,
+            col="name",
+            retcols=["name", "password_hash", "is_guest", "admin", "user_type"],
+            desc="search_users",
+        )
+
+
+def check_database_before_upgrade(cur, database_engine, config: HomeServerConfig):
+    """Called before upgrading an existing database to check that it is broadly sane
+    compared with the configuration.
+    """
+    domain = config.server_name
+
+    sql = database_engine.convert_param_style(
+        "SELECT COUNT(*) FROM users WHERE name NOT LIKE ?"
+    )
+    pat = "%:" + domain
+    cur.execute(sql, (pat,))
+    num_not_matching = cur.fetchall()[0][0]
+    if num_not_matching == 0:
+        return
+
+    raise Exception(
+        "Found users in database not native to %s!\n"
+        "You cannot changed a synapse server_name after it's been configured"
+        % (domain,)
+    )
+
+
+__all__ = ["DataStore", "check_database_before_upgrade"]
diff --git a/synapse/storage/account_data.py b/synapse/storage/data_stores/main/account_data.py
index 6afbfc0d74..b58f04d00d 100644
--- a/synapse/storage/account_data.py
+++ b/synapse/storage/data_stores/main/account_data.py
@@ -16,12 +16,14 @@
 
 import abc
 import logging
+from typing import List, Tuple
 
 from canonicaljson import json
 
 from twisted.internet import defer
 
 from synapse.storage._base import SQLBaseStore
+from synapse.storage.database import Database
 from synapse.storage.util.id_generators import StreamIdGenerator
 from synapse.util.caches.descriptors import cached, cachedInlineCallbacks
 from synapse.util.caches.stream_change_cache import StreamChangeCache
@@ -38,13 +40,13 @@ class AccountDataWorkerStore(SQLBaseStore):
     # the abstract methods being implemented.
     __metaclass__ = abc.ABCMeta
 
-    def __init__(self, db_conn, hs):
+    def __init__(self, database: Database, db_conn, hs):
         account_max = self.get_max_account_data_stream_id()
         self._account_data_stream_cache = StreamChangeCache(
             "AccountDataAndTagsChangeCache", account_max
         )
 
-        super(AccountDataWorkerStore, self).__init__(db_conn, hs)
+        super(AccountDataWorkerStore, self).__init__(database, db_conn, hs)
 
     @abc.abstractmethod
     def get_max_account_data_stream_id(self):
@@ -67,7 +69,7 @@ class AccountDataWorkerStore(SQLBaseStore):
         """
 
         def get_account_data_for_user_txn(txn):
-            rows = self._simple_select_list_txn(
+            rows = self.db.simple_select_list_txn(
                 txn,
                 "account_data",
                 {"user_id": user_id},
@@ -78,7 +80,7 @@ class AccountDataWorkerStore(SQLBaseStore):
                 row["account_data_type"]: json.loads(row["content"]) for row in rows
             }
 
-            rows = self._simple_select_list_txn(
+            rows = self.db.simple_select_list_txn(
                 txn,
                 "room_account_data",
                 {"user_id": user_id},
@@ -92,7 +94,7 @@ class AccountDataWorkerStore(SQLBaseStore):
 
             return global_account_data, by_room
 
-        return self.runInteraction(
+        return self.db.runInteraction(
             "get_account_data_for_user", get_account_data_for_user_txn
         )
 
@@ -102,7 +104,7 @@ class AccountDataWorkerStore(SQLBaseStore):
         Returns:
             Deferred: A dict
         """
-        result = yield self._simple_select_one_onecol(
+        result = yield self.db.simple_select_one_onecol(
             table="account_data",
             keyvalues={"user_id": user_id, "account_data_type": data_type},
             retcol="content",
@@ -127,7 +129,7 @@ class AccountDataWorkerStore(SQLBaseStore):
         """
 
         def get_account_data_for_room_txn(txn):
-            rows = self._simple_select_list_txn(
+            rows = self.db.simple_select_list_txn(
                 txn,
                 "room_account_data",
                 {"user_id": user_id, "room_id": room_id},
@@ -138,7 +140,7 @@ class AccountDataWorkerStore(SQLBaseStore):
                 row["account_data_type"]: json.loads(row["content"]) for row in rows
             }
 
-        return self.runInteraction(
+        return self.db.runInteraction(
             "get_account_data_for_room", get_account_data_for_room_txn
         )
 
@@ -156,7 +158,7 @@ class AccountDataWorkerStore(SQLBaseStore):
         """
 
         def get_account_data_for_room_and_type_txn(txn):
-            content_json = self._simple_select_one_onecol_txn(
+            content_json = self.db.simple_select_one_onecol_txn(
                 txn,
                 table="room_account_data",
                 keyvalues={
@@ -170,45 +172,68 @@ class AccountDataWorkerStore(SQLBaseStore):
 
             return json.loads(content_json) if content_json else None
 
-        return self.runInteraction(
+        return self.db.runInteraction(
             "get_account_data_for_room_and_type", get_account_data_for_room_and_type_txn
         )
 
-    def get_all_updated_account_data(
-        self, last_global_id, last_room_id, current_id, limit
-    ):
-        """Get all the client account_data that has changed on the server
+    async def get_updated_global_account_data(
+        self, last_id: int, current_id: int, limit: int
+    ) -> List[Tuple[int, str, str]]:
+        """Get the global account_data that has changed, for the account_data stream
+
         Args:
-            last_global_id(int): The position to fetch from for top level data
-            last_room_id(int): The position to fetch from for per room data
-            current_id(int): The position to fetch up to.
+            last_id: the last stream_id from the previous batch.
+            current_id: the maximum stream_id to return up to
+            limit: the maximum number of rows to return
+
         Returns:
-            A deferred pair of lists of tuples of stream_id int, user_id string,
-            room_id string, type string, and content string.
+            A list of tuples of stream_id int, user_id string,
+            and type string.
         """
-        if last_room_id == current_id and last_global_id == current_id:
-            return defer.succeed(([], []))
+        if last_id == current_id:
+            return []
 
-        def get_updated_account_data_txn(txn):
+        def get_updated_global_account_data_txn(txn):
             sql = (
-                "SELECT stream_id, user_id, account_data_type, content"
+                "SELECT stream_id, user_id, account_data_type"
                 " FROM account_data WHERE ? < stream_id AND stream_id <= ?"
                 " ORDER BY stream_id ASC LIMIT ?"
             )
-            txn.execute(sql, (last_global_id, current_id, limit))
-            global_results = txn.fetchall()
+            txn.execute(sql, (last_id, current_id, limit))
+            return txn.fetchall()
+
+        return await self.db.runInteraction(
+            "get_updated_global_account_data", get_updated_global_account_data_txn
+        )
+
+    async def get_updated_room_account_data(
+        self, last_id: int, current_id: int, limit: int
+    ) -> List[Tuple[int, str, str, str]]:
+        """Get the global account_data that has changed, for the account_data stream
 
+        Args:
+            last_id: the last stream_id from the previous batch.
+            current_id: the maximum stream_id to return up to
+            limit: the maximum number of rows to return
+
+        Returns:
+            A list of tuples of stream_id int, user_id string,
+            room_id string and type string.
+        """
+        if last_id == current_id:
+            return []
+
+        def get_updated_room_account_data_txn(txn):
             sql = (
-                "SELECT stream_id, user_id, room_id, account_data_type, content"
+                "SELECT stream_id, user_id, room_id, account_data_type"
                 " FROM room_account_data WHERE ? < stream_id AND stream_id <= ?"
                 " ORDER BY stream_id ASC LIMIT ?"
             )
-            txn.execute(sql, (last_room_id, current_id, limit))
-            room_results = txn.fetchall()
-            return global_results, room_results
+            txn.execute(sql, (last_id, current_id, limit))
+            return txn.fetchall()
 
-        return self.runInteraction(
-            "get_all_updated_account_data_txn", get_updated_account_data_txn
+        return await self.db.runInteraction(
+            "get_updated_room_account_data", get_updated_room_account_data_txn
         )
 
     def get_updated_account_data_for_user(self, user_id, stream_id):
@@ -250,9 +275,9 @@ class AccountDataWorkerStore(SQLBaseStore):
             user_id, int(stream_id)
         )
         if not changed:
-            return {}, {}
+            return defer.succeed(({}, {}))
 
-        return self.runInteraction(
+        return self.db.runInteraction(
             "get_updated_account_data_for_user", get_updated_account_data_for_user_txn
         )
 
@@ -270,12 +295,18 @@ class AccountDataWorkerStore(SQLBaseStore):
 
 
 class AccountDataStore(AccountDataWorkerStore):
-    def __init__(self, db_conn, hs):
+    def __init__(self, database: Database, db_conn, hs):
         self._account_data_id_gen = StreamIdGenerator(
-            db_conn, "account_data_max_stream_id", "stream_id"
+            db_conn,
+            "account_data_max_stream_id",
+            "stream_id",
+            extra_tables=[
+                ("room_account_data", "stream_id"),
+                ("room_tags_revisions", "stream_id"),
+            ],
         )
 
-        super(AccountDataStore, self).__init__(db_conn, hs)
+        super(AccountDataStore, self).__init__(database, db_conn, hs)
 
     def get_max_account_data_stream_id(self):
         """Get the current max stream id for the private user data stream
@@ -300,9 +331,9 @@ class AccountDataStore(AccountDataWorkerStore):
 
         with self._account_data_id_gen.get_next() as next_id:
             # no need to lock here as room_account_data has a unique constraint
-            # on (user_id, room_id, account_data_type) so _simple_upsert will
+            # on (user_id, room_id, account_data_type) so simple_upsert will
             # retry if there is a conflict.
-            yield self._simple_upsert(
+            yield self.db.simple_upsert(
                 desc="add_room_account_data",
                 table="room_account_data",
                 keyvalues={
@@ -346,9 +377,9 @@ class AccountDataStore(AccountDataWorkerStore):
 
         with self._account_data_id_gen.get_next() as next_id:
             # no need to lock here as account_data has a unique constraint on
-            # (user_id, account_data_type) so _simple_upsert will retry if
+            # (user_id, account_data_type) so simple_upsert will retry if
             # there is a conflict.
-            yield self._simple_upsert(
+            yield self.db.simple_upsert(
                 desc="add_user_account_data",
                 table="account_data",
                 keyvalues={"user_id": user_id, "account_data_type": account_data_type},
@@ -362,6 +393,10 @@ class AccountDataStore(AccountDataWorkerStore):
             # doesn't sound any worse than the whole update getting lost,
             # which is what would happen if we combined the two into one
             # transaction.
+            #
+            # Note: This is only here for backwards compat to allow admins to
+            # roll back to a previous Synapse version. Next time we update the
+            # database version we can remove this table.
             yield self._update_max_stream_id(next_id)
 
             self._account_data_stream_cache.entity_has_changed(user_id, next_id)
@@ -380,6 +415,10 @@ class AccountDataStore(AccountDataWorkerStore):
             next_id(int): The revision to advance to.
         """
 
+        # Note: This is only here for backwards compat to allow admins to
+        # roll back to a previous Synapse version. Next time we update the
+        # database version we can remove this table.
+
         def _update(txn):
             update_max_id_sql = (
                 "UPDATE account_data_max_stream_id"
@@ -388,4 +427,4 @@ class AccountDataStore(AccountDataWorkerStore):
             )
             txn.execute(update_max_id_sql, (next_id, next_id))
 
-        return self.runInteraction("update_account_data_max_stream_id", _update)
+        return self.db.runInteraction("update_account_data_max_stream_id", _update)
diff --git a/synapse/storage/appservice.py b/synapse/storage/data_stores/main/appservice.py
index 435b2acd4d..7a1fe8cdd2 100644
--- a/synapse/storage/appservice.py
+++ b/synapse/storage/data_stores/main/appservice.py
@@ -22,20 +22,20 @@ from twisted.internet import defer
 
 from synapse.appservice import AppServiceTransaction
 from synapse.config.appservice import load_appservices
-from synapse.storage.events_worker import EventsWorkerStore
-
-from ._base import SQLBaseStore
+from synapse.storage._base import SQLBaseStore
+from synapse.storage.data_stores.main.events_worker import EventsWorkerStore
+from synapse.storage.database import Database
 
 logger = logging.getLogger(__name__)
 
 
 def _make_exclusive_regex(services_cache):
-    # We precompie a regex constructed from all the regexes that the AS's
+    # We precompile a regex constructed from all the regexes that the AS's
     # have registered for exclusive users.
     exclusive_user_regexes = [
         regex.pattern
         for service in services_cache
-        for regex in service.get_exlusive_user_regexes()
+        for regex in service.get_exclusive_user_regexes()
     ]
     if exclusive_user_regexes:
         exclusive_user_regex = "|".join("(" + r + ")" for r in exclusive_user_regexes)
@@ -49,13 +49,13 @@ def _make_exclusive_regex(services_cache):
 
 
 class ApplicationServiceWorkerStore(SQLBaseStore):
-    def __init__(self, db_conn, hs):
+    def __init__(self, database: Database, db_conn, hs):
         self.services_cache = load_appservices(
             hs.hostname, hs.config.app_service_config_files
         )
         self.exclusive_user_regex = _make_exclusive_regex(self.services_cache)
 
-        super(ApplicationServiceWorkerStore, self).__init__(db_conn, hs)
+        super(ApplicationServiceWorkerStore, self).__init__(database, db_conn, hs)
 
     def get_app_services(self):
         return self.services_cache
@@ -134,8 +134,8 @@ class ApplicationServiceTransactionWorkerStore(
             A Deferred which resolves to a list of ApplicationServices, which
             may be empty.
         """
-        results = yield self._simple_select_list(
-            "application_services_state", dict(state=state), ["as_id"]
+        results = yield self.db.simple_select_list(
+            "application_services_state", {"state": state}, ["as_id"]
         )
         # NB: This assumes this class is linked with ApplicationServiceStore
         as_list = self.get_app_services()
@@ -156,9 +156,9 @@ class ApplicationServiceTransactionWorkerStore(
         Returns:
             A Deferred which resolves to ApplicationServiceState.
         """
-        result = yield self._simple_select_one(
+        result = yield self.db.simple_select_one(
             "application_services_state",
-            dict(as_id=service.id),
+            {"as_id": service.id},
             ["state"],
             allow_none=True,
             desc="get_appservice_state",
@@ -176,8 +176,8 @@ class ApplicationServiceTransactionWorkerStore(
         Returns:
             A Deferred which resolves when the state was set successfully.
         """
-        return self._simple_upsert(
-            "application_services_state", dict(as_id=service.id), dict(state=state)
+        return self.db.simple_upsert(
+            "application_services_state", {"as_id": service.id}, {"state": state}
         )
 
     def create_appservice_txn(self, service, events):
@@ -217,7 +217,7 @@ class ApplicationServiceTransactionWorkerStore(
             )
             return AppServiceTransaction(service=service, id=new_txn_id, events=events)
 
-        return self.runInteraction("create_appservice_txn", _create_appservice_txn)
+        return self.db.runInteraction("create_appservice_txn", _create_appservice_txn)
 
     def complete_appservice_txn(self, txn_id, service):
         """Completes an application service transaction.
@@ -250,19 +250,23 @@ class ApplicationServiceTransactionWorkerStore(
                 )
 
             # Set current txn_id for AS to 'txn_id'
-            self._simple_upsert_txn(
+            self.db.simple_upsert_txn(
                 txn,
                 "application_services_state",
-                dict(as_id=service.id),
-                dict(last_txn=txn_id),
+                {"as_id": service.id},
+                {"last_txn": txn_id},
             )
 
             # Delete txn
-            self._simple_delete_txn(
-                txn, "application_services_txns", dict(txn_id=txn_id, as_id=service.id)
+            self.db.simple_delete_txn(
+                txn,
+                "application_services_txns",
+                {"txn_id": txn_id, "as_id": service.id},
             )
 
-        return self.runInteraction("complete_appservice_txn", _complete_appservice_txn)
+        return self.db.runInteraction(
+            "complete_appservice_txn", _complete_appservice_txn
+        )
 
     @defer.inlineCallbacks
     def get_oldest_unsent_txn(self, service):
@@ -284,7 +288,7 @@ class ApplicationServiceTransactionWorkerStore(
                 " ORDER BY txn_id ASC LIMIT 1",
                 (service.id,),
             )
-            rows = self.cursor_to_dict(txn)
+            rows = self.db.cursor_to_dict(txn)
             if not rows:
                 return None
 
@@ -292,7 +296,7 @@ class ApplicationServiceTransactionWorkerStore(
 
             return entry
 
-        entry = yield self.runInteraction(
+        entry = yield self.db.runInteraction(
             "get_oldest_unsent_appservice_txn", _get_oldest_unsent_txn
         )
 
@@ -322,7 +326,7 @@ class ApplicationServiceTransactionWorkerStore(
                 "UPDATE appservice_stream_position SET stream_ordering = ?", (pos,)
             )
 
-        return self.runInteraction(
+        return self.db.runInteraction(
             "set_appservice_last_pos", set_appservice_last_pos_txn
         )
 
@@ -351,7 +355,7 @@ class ApplicationServiceTransactionWorkerStore(
 
             return upper_bound, [row[1] for row in rows]
 
-        upper_bound, event_ids = yield self.runInteraction(
+        upper_bound, event_ids = yield self.db.runInteraction(
             "get_new_events_for_appservice", get_new_events_for_appservice_txn
         )
 
diff --git a/synapse/storage/data_stores/main/cache.py b/synapse/storage/data_stores/main/cache.py
new file mode 100644
index 0000000000..eac5a4e55b
--- /dev/null
+++ b/synapse/storage/data_stores/main/cache.py
@@ -0,0 +1,280 @@
+# -*- coding: utf-8 -*-
+# Copyright 2019 The Matrix.org Foundation C.I.C.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+
+import itertools
+import logging
+from typing import Any, Iterable, Optional, Tuple
+
+from synapse.api.constants import EventTypes
+from synapse.replication.tcp.streams.events import (
+    EventsStreamCurrentStateRow,
+    EventsStreamEventRow,
+)
+from synapse.storage._base import SQLBaseStore
+from synapse.storage.database import Database
+from synapse.storage.engines import PostgresEngine
+from synapse.util.iterutils import batch_iter
+
+logger = logging.getLogger(__name__)
+
+
+# This is a special cache name we use to batch multiple invalidations of caches
+# based on the current state when notifying workers over replication.
+CURRENT_STATE_CACHE_NAME = "cs_cache_fake"
+
+
+class CacheInvalidationWorkerStore(SQLBaseStore):
+    def __init__(self, database: Database, db_conn, hs):
+        super().__init__(database, db_conn, hs)
+
+        self._instance_name = hs.get_instance_name()
+
+    async def get_all_updated_caches(
+        self, instance_name: str, last_id: int, current_id: int, limit: int
+    ):
+        """Fetches cache invalidation rows between the two given IDs written
+        by the given instance. Returns at most `limit` rows.
+        """
+
+        if last_id == current_id:
+            return []
+
+        def get_all_updated_caches_txn(txn):
+            # We purposefully don't bound by the current token, as we want to
+            # send across cache invalidations as quickly as possible. Cache
+            # invalidations are idempotent, so duplicates are fine.
+            sql = """
+                SELECT stream_id, cache_func, keys, invalidation_ts
+                FROM cache_invalidation_stream_by_instance
+                WHERE stream_id > ? AND instance_name = ?
+                ORDER BY stream_id ASC
+                LIMIT ?
+            """
+            txn.execute(sql, (last_id, instance_name, limit))
+            return txn.fetchall()
+
+        return await self.db.runInteraction(
+            "get_all_updated_caches", get_all_updated_caches_txn
+        )
+
+    def process_replication_rows(self, stream_name, instance_name, token, rows):
+        if stream_name == "events":
+            for row in rows:
+                self._process_event_stream_row(token, row)
+        elif stream_name == "backfill":
+            for row in rows:
+                self._invalidate_caches_for_event(
+                    -token,
+                    row.event_id,
+                    row.room_id,
+                    row.type,
+                    row.state_key,
+                    row.redacts,
+                    row.relates_to,
+                    backfilled=True,
+                )
+        elif stream_name == "caches":
+            if self._cache_id_gen:
+                self._cache_id_gen.advance(instance_name, token)
+
+            for row in rows:
+                if row.cache_func == CURRENT_STATE_CACHE_NAME:
+                    if row.keys is None:
+                        raise Exception(
+                            "Can't send an 'invalidate all' for current state cache"
+                        )
+
+                    room_id = row.keys[0]
+                    members_changed = set(row.keys[1:])
+                    self._invalidate_state_caches(room_id, members_changed)
+                else:
+                    self._attempt_to_invalidate_cache(row.cache_func, row.keys)
+
+        super().process_replication_rows(stream_name, instance_name, token, rows)
+
+    def _process_event_stream_row(self, token, row):
+        data = row.data
+
+        if row.type == EventsStreamEventRow.TypeId:
+            self._invalidate_caches_for_event(
+                token,
+                data.event_id,
+                data.room_id,
+                data.type,
+                data.state_key,
+                data.redacts,
+                data.relates_to,
+                backfilled=False,
+            )
+        elif row.type == EventsStreamCurrentStateRow.TypeId:
+            self._curr_state_delta_stream_cache.entity_has_changed(
+                row.data.room_id, token
+            )
+
+            if data.type == EventTypes.Member:
+                self.get_rooms_for_user_with_stream_ordering.invalidate(
+                    (data.state_key,)
+                )
+        else:
+            raise Exception("Unknown events stream row type %s" % (row.type,))
+
+    def _invalidate_caches_for_event(
+        self,
+        stream_ordering,
+        event_id,
+        room_id,
+        etype,
+        state_key,
+        redacts,
+        relates_to,
+        backfilled,
+    ):
+        self._invalidate_get_event_cache(event_id)
+
+        self.get_latest_event_ids_in_room.invalidate((room_id,))
+
+        self.get_unread_event_push_actions_by_room_for_user.invalidate_many((room_id,))
+
+        if not backfilled:
+            self._events_stream_cache.entity_has_changed(room_id, stream_ordering)
+
+        if redacts:
+            self._invalidate_get_event_cache(redacts)
+
+        if etype == EventTypes.Member:
+            self._membership_stream_cache.entity_has_changed(state_key, stream_ordering)
+            self.get_invited_rooms_for_local_user.invalidate((state_key,))
+
+        if relates_to:
+            self.get_relations_for_event.invalidate_many((relates_to,))
+            self.get_aggregation_groups_for_event.invalidate_many((relates_to,))
+            self.get_applicable_edit.invalidate((relates_to,))
+
+    async def invalidate_cache_and_stream(self, cache_name: str, keys: Tuple[Any, ...]):
+        """Invalidates the cache and adds it to the cache stream so slaves
+        will know to invalidate their caches.
+
+        This should only be used to invalidate caches where slaves won't
+        otherwise know from other replication streams that the cache should
+        be invalidated.
+        """
+        cache_func = getattr(self, cache_name, None)
+        if not cache_func:
+            return
+
+        cache_func.invalidate(keys)
+        await self.db.runInteraction(
+            "invalidate_cache_and_stream",
+            self._send_invalidation_to_replication,
+            cache_func.__name__,
+            keys,
+        )
+
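For illustration only, a minimal sketch of a caller of the helper above; the cache name is an assumption (any store method decorated with @cached and keyed on (user_id,) would do):

# Hypothetical caller: after mutating a user's row directly, drop the local
# cache entry and stream the invalidation so other workers drop theirs too.
async def on_user_row_changed(store, user_id: str):
    # "get_user_by_id" stands in for any @cached store method keyed on (user_id,).
    await store.invalidate_cache_and_stream("get_user_by_id", (user_id,))
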
+    def _invalidate_cache_and_stream(self, txn, cache_func, keys):
+        """Invalidates the cache and adds it to the cache stream so slaves
+        will know to invalidate their caches.
+
+        This should only be used to invalidate caches where slaves won't
+        otherwise know from other replication streams that the cache should
+        be invalidated.
+        """
+        txn.call_after(cache_func.invalidate, keys)
+        self._send_invalidation_to_replication(txn, cache_func.__name__, keys)
+
+    def _invalidate_all_cache_and_stream(self, txn, cache_func):
+        """Invalidates the entire cache and adds it to the cache stream so slaves
+        will know to invalidate their caches.
+        """
+
+        txn.call_after(cache_func.invalidate_all)
+        self._send_invalidation_to_replication(txn, cache_func.__name__, None)
+
+    def _invalidate_state_caches_and_stream(self, txn, room_id, members_changed):
+        """Special case invalidation of caches based on current state.
+
+        We special case this so that we can batch the cache invalidations into a
+        single replication poke.
+
+        Args:
+            txn
+            room_id (str): Room where state changed
+            members_changed (iterable[str]): The user_ids of members that have changed
+        """
+        txn.call_after(self._invalidate_state_caches, room_id, members_changed)
+
+        if members_changed:
+            # We need to be careful that the size of the `members_changed` list
+            # isn't so large that it causes problems sending over replication, so we
+            # send them in chunks.
+            # Max line length is 16K, and max user ID length is 255, so 50 should
+            # be safe.
+            for chunk in batch_iter(members_changed, 50):
+                keys = itertools.chain([room_id], chunk)
+                self._send_invalidation_to_replication(
+                    txn, CURRENT_STATE_CACHE_NAME, keys
+                )
+        else:
+            # if no members changed, we still need to invalidate the other caches.
+            self._send_invalidation_to_replication(
+                txn, CURRENT_STATE_CACHE_NAME, [room_id]
+            )
+
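As a standalone illustration of the chunking above (batch_iter and the chunk size of 50 come from the code; the room and member IDs are made up):

import itertools

from synapse.util.iterutils import batch_iter

room_id = "!room:example.org"
members_changed = ["@user%d:example.org" % i for i in range(120)]

# Mirrors the loop above: each replication poke carries the room ID followed by
# at most 50 changed member IDs, so 120 changed members produce three pokes.
for chunk in batch_iter(members_changed, 50):
    keys = list(itertools.chain([room_id], chunk))
    print(len(keys))  # 51, 51, 21
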
+    def _send_invalidation_to_replication(
+        self, txn, cache_name: str, keys: Optional[Iterable[Any]]
+    ):
+        """Notifies replication that given cache has been invalidated.
+
+        Note that this does *not* invalidate the cache locally.
+
+        Args:
+            txn
+            cache_name
+            keys: Entry to invalidate. If None will invalidate all.
+        """
+
+        if cache_name == CURRENT_STATE_CACHE_NAME and keys is None:
+            raise Exception(
+                "Can't stream invalidate all with magic current state cache"
+            )
+
+        if isinstance(self.database_engine, PostgresEngine):
+            # get_next() returns a context manager which is designed to wrap
+            # the transaction. However, we want to only get an ID when we want
+            # to use it, here, so we need to call __enter__ manually, and have
+            # __exit__ called after the transaction finishes.
+            stream_id = self._cache_id_gen.get_next_txn(txn)
+            txn.call_after(self.hs.get_notifier().on_new_replication_data)
+
+            if keys is not None:
+                keys = list(keys)
+
+            self.db.simple_insert_txn(
+                txn,
+                table="cache_invalidation_stream_by_instance",
+                values={
+                    "stream_id": stream_id,
+                    "instance_name": self._instance_name,
+                    "cache_func": cache_name,
+                    "keys": keys,
+                    "invalidation_ts": self.clock.time_msec(),
+                },
+            )
+
+    def get_cache_stream_token(self, instance_name):
+        if self._cache_id_gen:
+            return self._cache_id_gen.get_current_token(instance_name)
+        else:
+            return 0
diff --git a/synapse/storage/data_stores/main/censor_events.py b/synapse/storage/data_stores/main/censor_events.py
new file mode 100644
index 0000000000..2d48261724
--- /dev/null
+++ b/synapse/storage/data_stores/main/censor_events.py
@@ -0,0 +1,208 @@
+# -*- coding: utf-8 -*-
+# Copyright 2020 The Matrix.org Foundation C.I.C.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import logging
+from typing import TYPE_CHECKING
+
+from twisted.internet import defer
+
+from synapse.events.utils import prune_event_dict
+from synapse.metrics.background_process_metrics import run_as_background_process
+from synapse.storage._base import SQLBaseStore
+from synapse.storage.data_stores.main.cache import CacheInvalidationWorkerStore
+from synapse.storage.data_stores.main.events import encode_json
+from synapse.storage.data_stores.main.events_worker import EventsWorkerStore
+from synapse.storage.database import Database
+
+if TYPE_CHECKING:
+    from synapse.server import HomeServer
+
+
+logger = logging.getLogger(__name__)
+
+
+class CensorEventsStore(EventsWorkerStore, CacheInvalidationWorkerStore, SQLBaseStore):
+    def __init__(self, database: Database, db_conn, hs: "HomeServer"):
+        super().__init__(database, db_conn, hs)
+
+        def _censor_redactions():
+            return run_as_background_process(
+                "_censor_redactions", self._censor_redactions
+            )
+
+        if self.hs.config.redaction_retention_period is not None:
+            hs.get_clock().looping_call(_censor_redactions, 5 * 60 * 1000)
+
+    async def _censor_redactions(self):
+        """Censors all redactions older than the configured period that haven't
+        been censored yet.
+
+        By censor we mean update the event_json table with the redacted event.
+        """
+
+        if self.hs.config.redaction_retention_period is None:
+            return
+
+        if not (
+            await self.db.updates.has_completed_background_update(
+                "redactions_have_censored_ts_idx"
+            )
+        ):
+            # We don't want to run this until the appropriate index has been
+            # created.
+            return
+
+        before_ts = self._clock.time_msec() - self.hs.config.redaction_retention_period
+
+        # We fetch all redactions that:
+        #   1. point to an event we have,
+        #   2. have a received_ts from before the cutoff, and
+        #   3. we haven't yet censored.
+        #
+        # This is limited to 100 events to ensure that we don't try and do too
+        # much at once. We'll get called again so this should eventually catch
+        # up.
+        sql = """
+            SELECT redactions.event_id, redacts FROM redactions
+            LEFT JOIN events AS original_event ON (
+                redacts = original_event.event_id
+            )
+            WHERE NOT have_censored
+            AND redactions.received_ts <= ?
+            ORDER BY redactions.received_ts ASC
+            LIMIT ?
+        """
+
+        rows = await self.db.execute(
+            "_censor_redactions_fetch", None, sql, before_ts, 100
+        )
+
+        updates = []
+
+        for redaction_id, event_id in rows:
+            redaction_event = await self.get_event(redaction_id, allow_none=True)
+            original_event = await self.get_event(
+                event_id, allow_rejected=True, allow_none=True
+            )
+
+            # The SQL above ensures that we have both the redaction and
+            # original event, so if the `get_event` calls return None it
+            # means that the redaction wasn't allowed. Either way we know that
+            # the result won't change so we mark the fact that we've checked.
+            if (
+                redaction_event
+                and original_event
+                and original_event.internal_metadata.is_redacted()
+            ):
+                # Redaction was allowed
+                pruned_json = encode_json(
+                    prune_event_dict(
+                        original_event.room_version, original_event.get_dict()
+                    )
+                )
+            else:
+                # Redaction wasn't allowed
+                pruned_json = None
+
+            updates.append((redaction_id, event_id, pruned_json))
+
+        def _update_censor_txn(txn):
+            for redaction_id, event_id, pruned_json in updates:
+                if pruned_json:
+                    self._censor_event_txn(txn, event_id, pruned_json)
+
+                self.db.simple_update_one_txn(
+                    txn,
+                    table="redactions",
+                    keyvalues={"event_id": redaction_id},
+                    updatevalues={"have_censored": True},
+                )
+
+        await self.db.runInteraction("_update_censor_txn", _update_censor_txn)
+
+    def _censor_event_txn(self, txn, event_id, pruned_json):
+        """Censor an event by replacing its JSON in the event_json table with the
+        provided pruned JSON.
+
+        Args:
+            txn (LoggingTransaction): The database transaction.
+            event_id (str): The ID of the event to censor.
+            pruned_json (str): The pruned JSON
+        """
+        self.db.simple_update_one_txn(
+            txn,
+            table="event_json",
+            keyvalues={"event_id": event_id},
+            updatevalues={"json": pruned_json},
+        )
+
+    @defer.inlineCallbacks
+    def expire_event(self, event_id):
+        """Retrieve and expire an event that has expired, and delete its associated
+        expiry timestamp. If the event can't be retrieved, delete its associated
+        timestamp so we don't try to expire it again in the future.
+
+        Args:
+             event_id (str): The ID of the event to delete.
+        """
+        # Try to retrieve the event's content from the database or the event cache.
+        event = yield self.get_event(event_id)
+
+        def delete_expired_event_txn(txn):
+            # Delete the expiry timestamp associated with this event from the database.
+            self._delete_event_expiry_txn(txn, event_id)
+
+            if not event:
+                # If we can't find the event, log a warning and delete the expiry date
+                # from the database so that we don't try to expire it again in the
+                # future.
+                logger.warning(
+                    "Can't expire event %s because we don't have it.", event_id
+                )
+                return
+
+            # Prune the event's dict then convert it to JSON.
+            pruned_json = encode_json(
+                prune_event_dict(event.room_version, event.get_dict())
+            )
+
+            # Update the event_json table to replace the event's JSON with the pruned
+            # JSON.
+            self._censor_event_txn(txn, event.event_id, pruned_json)
+
+            # We need to invalidate the event cache entry for this event because we
+            # changed its content in the database. We can't call
+            # self._invalidate_cache_and_stream because self.get_event_cache isn't of the
+            # right type.
+            txn.call_after(self._get_event_cache.invalidate, (event.event_id,))
+            # Send that invalidation to replication so that other workers also invalidate
+            # the event cache.
+            self._send_invalidation_to_replication(
+                txn, "_get_event_cache", (event.event_id,)
+            )
+
+        yield self.db.runInteraction("delete_expired_event", delete_expired_event_txn)
+
+    def _delete_event_expiry_txn(self, txn, event_id):
+        """Delete the expiry timestamp associated with an event ID without deleting the
+        actual event.
+
+        Args:
+            txn (LoggingTransaction): The transaction to use to perform the deletion.
+            event_id (str): The event ID to delete the associated expiry timestamp of.
+        """
+        return self.db.simple_delete_txn(
+            txn=txn, table="event_expiry", keyvalues={"event_id": event_id}
+        )
diff --git a/synapse/storage/client_ips.py b/synapse/storage/data_stores/main/client_ips.py
index 6db8c54077..71f8d43a76 100644
--- a/synapse/storage/client_ips.py
+++ b/synapse/storage/data_stores/main/client_ips.py
@@ -19,11 +19,10 @@ from six import iteritems
 
 from twisted.internet import defer
 
-from synapse.metrics.background_process_metrics import run_as_background_process
-from synapse.util.caches import CACHE_SIZE_FACTOR
-
-from . import background_updates
-from ._base import Cache
+from synapse.metrics.background_process_metrics import wrap_as_background_process
+from synapse.storage._base import SQLBaseStore
+from synapse.storage.database import Database, make_tuple_comparison_clause
+from synapse.util.caches.descriptors import Cache
 
 logger = logging.getLogger(__name__)
 
@@ -33,46 +32,41 @@ logger = logging.getLogger(__name__)
 LAST_SEEN_GRANULARITY = 120 * 1000
 
 
-class ClientIpStore(background_updates.BackgroundUpdateStore):
-    def __init__(self, db_conn, hs):
-
-        self.client_ip_last_seen = Cache(
-            name="client_ip_last_seen", keylen=4, max_entries=50000 * CACHE_SIZE_FACTOR
-        )
-
-        super(ClientIpStore, self).__init__(db_conn, hs)
+class ClientIpBackgroundUpdateStore(SQLBaseStore):
+    def __init__(self, database: Database, db_conn, hs):
+        super(ClientIpBackgroundUpdateStore, self).__init__(database, db_conn, hs)
 
-        self.register_background_index_update(
+        self.db.updates.register_background_index_update(
             "user_ips_device_index",
             index_name="user_ips_device_id",
             table="user_ips",
             columns=["user_id", "device_id", "last_seen"],
         )
 
-        self.register_background_index_update(
+        self.db.updates.register_background_index_update(
             "user_ips_last_seen_index",
             index_name="user_ips_last_seen",
             table="user_ips",
             columns=["user_id", "last_seen"],
         )
 
-        self.register_background_index_update(
+        self.db.updates.register_background_index_update(
             "user_ips_last_seen_only_index",
             index_name="user_ips_last_seen_only",
             table="user_ips",
             columns=["last_seen"],
         )
 
-        self.register_background_update_handler(
+        self.db.updates.register_background_update_handler(
             "user_ips_analyze", self._analyze_user_ip
         )
 
-        self.register_background_update_handler(
+        self.db.updates.register_background_update_handler(
             "user_ips_remove_dupes", self._remove_user_ip_dupes
         )
 
         # Register a unique index
-        self.register_background_index_update(
+        self.db.updates.register_background_index_update(
             "user_ips_device_unique_index",
             index_name="user_ips_user_token_ip_unique_index",
             table="user_ips",
@@ -81,18 +75,13 @@ class ClientIpStore(background_updates.BackgroundUpdateStore):
         )
 
         # Drop the old non-unique index
-        self.register_background_update_handler(
+        self.db.updates.register_background_update_handler(
             "user_ips_drop_nonunique_index", self._remove_user_ip_nonunique
         )
 
-        # (user_id, access_token, ip,) -> (user_agent, device_id, last_seen)
-        self._batch_row_update = {}
-
-        self._client_ip_looper = self._clock.looping_call(
-            self._update_client_ips_batch, 5 * 1000
-        )
-        self.hs.get_reactor().addSystemEventTrigger(
-            "before", "shutdown", self._update_client_ips_batch
+        # Update the last seen info in devices.
+        self.db.updates.register_background_update_handler(
+            "devices_last_seen", self._devices_last_seen_update
         )
 
     @defer.inlineCallbacks
@@ -102,8 +91,8 @@ class ClientIpStore(background_updates.BackgroundUpdateStore):
             txn.execute("DROP INDEX IF EXISTS user_ips_user_ip")
             txn.close()
 
-        yield self.runWithConnection(f)
-        yield self._end_background_update("user_ips_drop_nonunique_index")
+        yield self.db.runWithConnection(f)
+        yield self.db.updates._end_background_update("user_ips_drop_nonunique_index")
         return 1
 
     @defer.inlineCallbacks
@@ -117,9 +106,9 @@ class ClientIpStore(background_updates.BackgroundUpdateStore):
         def user_ips_analyze(txn):
             txn.execute("ANALYZE user_ips")
 
-        yield self.runInteraction("user_ips_analyze", user_ips_analyze)
+        yield self.db.runInteraction("user_ips_analyze", user_ips_analyze)
 
-        yield self._end_background_update("user_ips_analyze")
+        yield self.db.updates._end_background_update("user_ips_analyze")
 
         return 1
 
@@ -151,7 +140,7 @@ class ClientIpStore(background_updates.BackgroundUpdateStore):
                 return None
 
         # Get a last seen that has roughly `batch_size` since `begin_last_seen`
-        end_last_seen = yield self.runInteraction(
+        end_last_seen = yield self.db.runInteraction(
             "user_ips_dups_get_last_seen", get_last_seen
         )
 
@@ -282,18 +271,116 @@ class ClientIpStore(background_updates.BackgroundUpdateStore):
                     (user_id, access_token, ip, device_id, user_agent, last_seen),
                 )
 
-            self._background_update_progress_txn(
+            self.db.updates._background_update_progress_txn(
                 txn, "user_ips_remove_dupes", {"last_seen": end_last_seen}
             )
 
-        yield self.runInteraction("user_ips_dups_remove", remove)
+        yield self.db.runInteraction("user_ips_dups_remove", remove)
 
         if last:
-            yield self._end_background_update("user_ips_remove_dupes")
+            yield self.db.updates._end_background_update("user_ips_remove_dupes")
 
         return batch_size
 
     @defer.inlineCallbacks
+    def _devices_last_seen_update(self, progress, batch_size):
+        """Background update to insert last seen info into devices table
+        """
+
+        last_user_id = progress.get("last_user_id", "")
+        last_device_id = progress.get("last_device_id", "")
+
+        def _devices_last_seen_update_txn(txn):
+            # This consists of two queries:
+            #
+            #   1. The sub-query searches for the next N devices and joins
+            #      against user_ips to find the max last_seen associated with
+            #      that device.
+            #   2. The outer query then joins again against user_ips on
+            #      user/device/last_seen. This *should* hopefully only
+            #      return one row, but if it does return more than one then
+            #      we'll just end up updating the same device row multiple
+            #      times, which is fine.
+
+            where_clause, where_args = make_tuple_comparison_clause(
+                self.database_engine,
+                [("user_id", last_user_id), ("device_id", last_device_id)],
+            )
+
+            sql = """
+                SELECT
+                    last_seen, ip, user_agent, user_id, device_id
+                FROM (
+                    SELECT
+                        user_id, device_id, MAX(u.last_seen) AS last_seen
+                    FROM devices
+                    INNER JOIN user_ips AS u USING (user_id, device_id)
+                    WHERE %(where_clause)s
+                    GROUP BY user_id, device_id
+                    ORDER BY user_id ASC, device_id ASC
+                    LIMIT ?
+                ) c
+                INNER JOIN user_ips AS u USING (user_id, device_id, last_seen)
+            """ % {
+                "where_clause": where_clause
+            }
+            txn.execute(sql, where_args + [batch_size])
+
+            rows = txn.fetchall()
+            if not rows:
+                return 0
+
+            sql = """
+                UPDATE devices
+                SET last_seen = ?, ip = ?, user_agent = ?
+                WHERE user_id = ? AND device_id = ?
+            """
+            txn.execute_batch(sql, rows)
+
+            _, _, _, user_id, device_id = rows[-1]
+            self.db.updates._background_update_progress_txn(
+                txn,
+                "devices_last_seen",
+                {"last_user_id": user_id, "last_device_id": device_id},
+            )
+
+            return len(rows)
+
+        updated = yield self.db.runInteraction(
+            "_devices_last_seen_update", _devices_last_seen_update_txn
+        )
+
+        if not updated:
+            yield self.db.updates._end_background_update("devices_last_seen")
+
+        return updated
+
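A minimal sketch of the keyset pagination used by the background update above; the helper's exact SQL depends on the database engine, so the clause shown in the comment is indicative only, and the function below is illustrative rather than part of this change:

from synapse.storage.database import make_tuple_comparison_clause

def page_of_devices_txn(txn, database_engine, last_user_id, last_device_id, batch_size):
    # Builds something equivalent to "(user_id, device_id) > (?, ?)", so each
    # batch resumes strictly after the last (user_id, device_id) processed.
    where_clause, where_args = make_tuple_comparison_clause(
        database_engine,
        [("user_id", last_user_id), ("device_id", last_device_id)],
    )
    sql = (
        "SELECT user_id, device_id FROM devices"
        " WHERE " + where_clause +
        " ORDER BY user_id ASC, device_id ASC LIMIT ?"
    )
    txn.execute(sql, where_args + [batch_size])
    return txn.fetchall()
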
+
+class ClientIpStore(ClientIpBackgroundUpdateStore):
+    def __init__(self, database: Database, db_conn, hs):
+
+        self.client_ip_last_seen = Cache(
+            name="client_ip_last_seen", keylen=4, max_entries=50000
+        )
+
+        super(ClientIpStore, self).__init__(database, db_conn, hs)
+
+        self.user_ips_max_age = hs.config.user_ips_max_age
+
+        # (user_id, access_token, ip,) -> (user_agent, device_id, last_seen)
+        self._batch_row_update = {}
+
+        self._client_ip_looper = self._clock.looping_call(
+            self._update_client_ips_batch, 5 * 1000
+        )
+        self.hs.get_reactor().addSystemEventTrigger(
+            "before", "shutdown", self._update_client_ips_batch
+        )
+
+        if self.user_ips_max_age:
+            self._clock.looping_call(self._prune_old_user_ips, 5 * 1000)
+
+    @defer.inlineCallbacks
     def insert_client_ip(
         self, user_id, access_token, ip, user_agent, device_id, now=None
     ):
@@ -314,23 +401,22 @@ class ClientIpStore(background_updates.BackgroundUpdateStore):
 
         self._batch_row_update[key] = (user_agent, device_id, now)
 
+    @wrap_as_background_process("update_client_ips")
     def _update_client_ips_batch(self):
 
         # If the DB pool has already terminated, don't try updating
-        if not self.hs.get_db_pool().running:
+        if not self.db.is_running():
             return
 
-        def update():
-            to_update = self._batch_row_update
-            self._batch_row_update = {}
-            return self.runInteraction(
-                "_update_client_ips_batch", self._update_client_ips_batch_txn, to_update
-            )
+        to_update = self._batch_row_update
+        self._batch_row_update = {}
 
-        return run_as_background_process("update_client_ips", update)
+        return self.db.runInteraction(
+            "_update_client_ips_batch", self._update_client_ips_batch_txn, to_update
+        )
 
     def _update_client_ips_batch_txn(self, txn, to_update):
-        if "user_ips" in self._unsafe_to_upsert_tables or (
+        if "user_ips" in self.db._unsafe_to_upsert_tables or (
             not self.database_engine.can_native_upsert
         ):
             self.database_engine.lock_table(txn, "user_ips")
@@ -339,7 +425,7 @@ class ClientIpStore(background_updates.BackgroundUpdateStore):
             (user_id, access_token, ip), (user_agent, device_id, last_seen) = entry
 
             try:
-                self._simple_upsert_txn(
+                self.db.simple_upsert_txn(
                     txn,
                     table="user_ips",
                     keyvalues={
@@ -354,6 +440,23 @@ class ClientIpStore(background_updates.BackgroundUpdateStore):
                     },
                     lock=False,
                 )
+
+                # Technically an access token might not be associated with
+                # a device so we need to check.
+                if device_id:
+                    # this is always an update rather than an upsert: the row should
+                    # already exist, and if it doesn't, that may be because it has been
+                    # deleted, and we don't want to re-create it.
+                    self.db.simple_update_txn(
+                        txn,
+                        table="devices",
+                        keyvalues={"user_id": user_id, "device_id": device_id},
+                        updatevalues={
+                            "user_agent": user_agent,
+                            "last_seen": last_seen,
+                            "ip": ip,
+                        },
+                    )
             except Exception as e:
                 # Failed to upsert, log and continue
                 logger.error("Failed to insert client IP %r: %r", entry, e)
@@ -372,19 +475,14 @@ class ClientIpStore(background_updates.BackgroundUpdateStore):
             keys giving the column names
         """
 
-        res = yield self.runInteraction(
-            "get_last_client_ip_by_device",
-            self._get_last_client_ip_by_device_txn,
-            user_id,
-            device_id,
-            retcols=(
-                "user_id",
-                "access_token",
-                "ip",
-                "user_agent",
-                "device_id",
-                "last_seen",
-            ),
+        keyvalues = {"user_id": user_id}
+        if device_id is not None:
+            keyvalues["device_id"] = device_id
+
+        res = yield self.db.simple_select_list(
+            table="devices",
+            keyvalues=keyvalues,
+            retcols=("user_id", "ip", "user_agent", "device_id", "last_seen"),
         )
 
         ret = {(d["user_id"], d["device_id"]): d for d in res}
@@ -403,42 +501,6 @@ class ClientIpStore(background_updates.BackgroundUpdateStore):
                     }
         return ret
 
-    @classmethod
-    def _get_last_client_ip_by_device_txn(cls, txn, user_id, device_id, retcols):
-        where_clauses = []
-        bindings = []
-        if device_id is None:
-            where_clauses.append("user_id = ?")
-            bindings.extend((user_id,))
-        else:
-            where_clauses.append("(user_id = ? AND device_id = ?)")
-            bindings.extend((user_id, device_id))
-
-        if not where_clauses:
-            return []
-
-        inner_select = (
-            "SELECT MAX(last_seen) mls, user_id, device_id FROM user_ips "
-            "WHERE %(where)s "
-            "GROUP BY user_id, device_id"
-        ) % {"where": " OR ".join(where_clauses)}
-
-        sql = (
-            "SELECT %(retcols)s FROM user_ips "
-            "JOIN (%(inner_select)s) ips ON"
-            "    user_ips.last_seen = ips.mls AND"
-            "    user_ips.user_id = ips.user_id AND"
-            "    (user_ips.device_id = ips.device_id OR"
-            "         (user_ips.device_id IS NULL AND ips.device_id IS NULL)"
-            "    )"
-        ) % {
-            "retcols": ",".join("user_ips." + c for c in retcols),
-            "inner_select": inner_select,
-        }
-
-        txn.execute(sql, bindings)
-        return cls.cursor_to_dict(txn)
-
     @defer.inlineCallbacks
     def get_user_ip_and_agents(self, user):
         user_id = user.to_string()
@@ -450,7 +512,7 @@ class ClientIpStore(background_updates.BackgroundUpdateStore):
                 user_agent, _, last_seen = self._batch_row_update[key]
                 results[(access_token, ip)] = (user_agent, last_seen)
 
-        rows = yield self._simple_select_list(
+        rows = yield self.db.simple_select_list(
             table="user_ips",
             keyvalues={"user_id": user_id},
             retcols=["access_token", "ip", "user_agent", "last_seen"],
@@ -461,7 +523,7 @@ class ClientIpStore(background_updates.BackgroundUpdateStore):
             ((row["access_token"], row["ip"]), (row["user_agent"], row["last_seen"]))
             for row in rows
         )
-        return list(
+        return [
             {
                 "access_token": access_token,
                 "ip": ip,
@@ -469,4 +531,48 @@ class ClientIpStore(background_updates.BackgroundUpdateStore):
                 "last_seen": last_seen,
             }
             for (access_token, ip), (user_agent, last_seen) in iteritems(results)
-        )
+        ]
+
+    @wrap_as_background_process("prune_old_user_ips")
+    async def _prune_old_user_ips(self):
+        """Removes entries in user IPs older than the configured period.
+        """
+
+        if self.user_ips_max_age is None:
+            # Nothing to do
+            return
+
+        if not await self.db.updates.has_completed_background_update(
+            "devices_last_seen"
+        ):
+            # Only start pruning if we have finished populating the devices
+            # last seen info.
+            return
+
+        # We do a slightly funky SQL delete to ensure we don't try and delete
+        # too much at once (as the table may be very large from before we
+        # started pruning).
+        #
+        # This works by finding the max last_seen that is less than the given
+        # time, but has no more than N rows before it, deleting all rows with
+        # a lesser last_seen time. (We COALESCE so that the sub-SELECT always
+        # returns exactly one row).
+        sql = """
+            DELETE FROM user_ips
+            WHERE last_seen <= (
+                SELECT COALESCE(MAX(last_seen), -1)
+                FROM (
+                    SELECT last_seen FROM user_ips
+                    WHERE last_seen <= ?
+                    ORDER BY last_seen ASC
+                    LIMIT 5000
+                ) AS u
+            )
+        """
+
+        timestamp = self.clock.time_msec() - self.user_ips_max_age
+
+        def _prune_old_user_ips_txn(txn):
+            txn.execute(sql, (timestamp,))
+
+        await self.db.runInteraction("_prune_old_user_ips", _prune_old_user_ips_txn)
diff --git a/synapse/storage/deviceinbox.py b/synapse/storage/data_stores/main/deviceinbox.py
index 6b7458304e..9a1178fb39 100644
--- a/synapse/storage/deviceinbox.py
+++ b/synapse/storage/data_stores/main/deviceinbox.py
@@ -20,8 +20,8 @@ from canonicaljson import json
 from twisted.internet import defer
 
 from synapse.logging.opentracing import log_kv, set_tag, trace
-from synapse.storage._base import SQLBaseStore
-from synapse.storage.background_updates import BackgroundUpdateStore
+from synapse.storage._base import SQLBaseStore, make_in_list_sql_clause
+from synapse.storage.database import Database
 from synapse.util.caches.expiringcache import ExpiringCache
 
 logger = logging.getLogger(__name__)
@@ -69,7 +69,7 @@ class DeviceInboxWorkerStore(SQLBaseStore):
                 stream_pos = current_stream_id
             return messages, stream_pos
 
-        return self.runInteraction(
+        return self.db.runInteraction(
             "get_new_messages_for_device", get_new_messages_for_device_txn
         )
 
@@ -109,7 +109,7 @@ class DeviceInboxWorkerStore(SQLBaseStore):
             txn.execute(sql, (user_id, device_id, up_to_stream_id))
             return txn.rowcount
 
-        count = yield self.runInteraction(
+        count = yield self.db.runInteraction(
             "delete_messages_for_device", delete_messages_for_device_txn
         )
 
@@ -178,7 +178,7 @@ class DeviceInboxWorkerStore(SQLBaseStore):
                 stream_pos = current_stream_id
             return messages, stream_pos
 
-        return self.runInteraction(
+        return self.db.runInteraction(
             "get_new_device_msgs_for_remote",
             get_new_messages_for_remote_destination_txn,
         )
@@ -203,28 +203,92 @@ class DeviceInboxWorkerStore(SQLBaseStore):
             )
             txn.execute(sql, (destination, up_to_stream_id))
 
-        return self.runInteraction(
+        return self.db.runInteraction(
             "delete_device_msgs_for_remote", delete_messages_for_remote_destination_txn
         )
 
+    def get_all_new_device_messages(self, last_pos, current_pos, limit):
+        """
+        Args:
+            last_pos(int):
+            current_pos(int):
+            limit(int):
+        Returns:
+            A deferred list of rows from the device inbox
+        """
+        if last_pos == current_pos:
+            return defer.succeed([])
 
-class DeviceInboxStore(DeviceInboxWorkerStore, BackgroundUpdateStore):
+        def get_all_new_device_messages_txn(txn):
+            # We limit like this as we might have multiple rows per stream_id, and
+            # we want to make sure we always get all entries for any stream_id
+            # we return.
+            upper_pos = min(current_pos, last_pos + limit)
+            sql = (
+                "SELECT max(stream_id), user_id"
+                " FROM device_inbox"
+                " WHERE ? < stream_id AND stream_id <= ?"
+                " GROUP BY user_id"
+            )
+            txn.execute(sql, (last_pos, upper_pos))
+            rows = txn.fetchall()
+
+            sql = (
+                "SELECT max(stream_id), destination"
+                " FROM device_federation_outbox"
+                " WHERE ? < stream_id AND stream_id <= ?"
+                " GROUP BY destination"
+            )
+            txn.execute(sql, (last_pos, upper_pos))
+            rows.extend(txn)
+
+            # Order by ascending stream ordering
+            rows.sort()
+
+            return rows
+
+        return self.db.runInteraction(
+            "get_all_new_device_messages", get_all_new_device_messages_txn
+        )
+
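For illustration of the bound used above (the numbers are made up), a given stream_id is always either wholly inside or wholly outside the batch:

# Made-up numbers: the batch covers 100 < stream_id <= 150, never splitting the
# rows that share a single stream_id across two batches.
last_pos, current_pos, limit = 100, 500, 50
upper_pos = min(current_pos, last_pos + limit)
assert upper_pos == 150
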
+
+class DeviceInboxBackgroundUpdateStore(SQLBaseStore):
     DEVICE_INBOX_STREAM_ID = "device_inbox_stream_drop"
 
-    def __init__(self, db_conn, hs):
-        super(DeviceInboxStore, self).__init__(db_conn, hs)
+    def __init__(self, database: Database, db_conn, hs):
+        super(DeviceInboxBackgroundUpdateStore, self).__init__(database, db_conn, hs)
 
-        self.register_background_index_update(
+        self.db.updates.register_background_index_update(
             "device_inbox_stream_index",
             index_name="device_inbox_stream_id_user_id",
             table="device_inbox",
             columns=["stream_id", "user_id"],
         )
 
-        self.register_background_update_handler(
+        self.db.updates.register_background_update_handler(
             self.DEVICE_INBOX_STREAM_ID, self._background_drop_index_device_inbox
         )
 
+    @defer.inlineCallbacks
+    def _background_drop_index_device_inbox(self, progress, batch_size):
+        def reindex_txn(conn):
+            txn = conn.cursor()
+            txn.execute("DROP INDEX IF EXISTS device_inbox_stream_id")
+            txn.close()
+
+        yield self.db.runWithConnection(reindex_txn)
+
+        yield self.db.updates._end_background_update(self.DEVICE_INBOX_STREAM_ID)
+
+        return 1
+
+
+class DeviceInboxStore(DeviceInboxWorkerStore, DeviceInboxBackgroundUpdateStore):
+    DEVICE_INBOX_STREAM_ID = "device_inbox_stream_drop"
+
+    def __init__(self, database: Database, db_conn, hs):
+        super(DeviceInboxStore, self).__init__(database, db_conn, hs)
+
         # Map of (user_id, device_id) to the last stream_id that has been
         # deleted up to. This is so that we can no op deletions.
         self._last_device_delete_cache = ExpiringCache(
@@ -274,7 +338,7 @@ class DeviceInboxStore(DeviceInboxWorkerStore, BackgroundUpdateStore):
 
         with self._device_inbox_id_gen.get_next() as stream_id:
             now_ms = self.clock.time_msec()
-            yield self.runInteraction(
+            yield self.db.runInteraction(
                 "add_messages_to_device_inbox", add_messages_txn, now_ms, stream_id
             )
             for user_id in local_messages_by_user_then_device.keys():
@@ -294,7 +358,7 @@ class DeviceInboxStore(DeviceInboxWorkerStore, BackgroundUpdateStore):
             # Check if we've already inserted a matching message_id for that
             # origin. This can happen if the origin doesn't receive our
             # acknowledgement from the first time we received the message.
-            already_inserted = self._simple_select_one_txn(
+            already_inserted = self.db.simple_select_one_txn(
                 txn,
                 table="device_federation_inbox",
                 keyvalues={"origin": origin, "message_id": message_id},
@@ -306,7 +370,7 @@ class DeviceInboxStore(DeviceInboxWorkerStore, BackgroundUpdateStore):
 
             # Add an entry for this message_id so that we know we've processed
             # it.
-            self._simple_insert_txn(
+            self.db.simple_insert_txn(
                 txn,
                 table="device_federation_inbox",
                 values={
@@ -324,7 +388,7 @@ class DeviceInboxStore(DeviceInboxWorkerStore, BackgroundUpdateStore):
 
         with self._device_inbox_id_gen.get_next() as stream_id:
             now_ms = self.clock.time_msec()
-            yield self.runInteraction(
+            yield self.db.runInteraction(
                 "add_messages_from_remote_to_device_inbox",
                 add_messages_txn,
                 now_ms,
@@ -347,7 +411,7 @@ class DeviceInboxStore(DeviceInboxWorkerStore, BackgroundUpdateStore):
             devices = list(messages_by_device.keys())
             if len(devices) == 1 and devices[0] == "*":
                 # Handle wildcard device_ids.
-                sql = "SELECT device_id FROM devices" " WHERE user_id = ?"
+                sql = "SELECT device_id FROM devices WHERE user_id = ?"
                 txn.execute(sql, (user_id,))
                 message_json = json.dumps(messages_by_device["*"])
                 for row in txn:
@@ -358,15 +422,15 @@ class DeviceInboxStore(DeviceInboxWorkerStore, BackgroundUpdateStore):
             else:
                 if not devices:
                     continue
-                sql = (
-                    "SELECT device_id FROM devices"
-                    " WHERE user_id = ? AND device_id IN ("
-                    + ",".join("?" * len(devices))
-                    + ")"
+
+                clause, args = make_in_list_sql_clause(
+                    txn.database_engine, "device_id", devices
                 )
+                sql = "SELECT device_id FROM devices WHERE user_id = ? AND " + clause
+
                 # TODO: Maybe this needs to be done in batches if there are
                 # too many local devices for a given user.
-                txn.execute(sql, [user_id] + devices)
+                txn.execute(sql, [user_id] + list(args))
                 for row in txn:
                     # Only insert into the local inbox if the device exists on
                     # this server
@@ -391,60 +455,3 @@ class DeviceInboxStore(DeviceInboxWorkerStore, BackgroundUpdateStore):
                 rows.append((user_id, device_id, stream_id, message_json))
 
         txn.executemany(sql, rows)
-
-    def get_all_new_device_messages(self, last_pos, current_pos, limit):
-        """
-        Args:
-            last_pos(int):
-            current_pos(int):
-            limit(int):
-        Returns:
-            A deferred list of rows from the device inbox
-        """
-        if last_pos == current_pos:
-            return defer.succeed([])
-
-        def get_all_new_device_messages_txn(txn):
-            # We limit like this as we might have multiple rows per stream_id, and
-            # we want to make sure we always get all entries for any stream_id
-            # we return.
-            upper_pos = min(current_pos, last_pos + limit)
-            sql = (
-                "SELECT max(stream_id), user_id"
-                " FROM device_inbox"
-                " WHERE ? < stream_id AND stream_id <= ?"
-                " GROUP BY user_id"
-            )
-            txn.execute(sql, (last_pos, upper_pos))
-            rows = txn.fetchall()
-
-            sql = (
-                "SELECT max(stream_id), destination"
-                " FROM device_federation_outbox"
-                " WHERE ? < stream_id AND stream_id <= ?"
-                " GROUP BY destination"
-            )
-            txn.execute(sql, (last_pos, upper_pos))
-            rows.extend(txn)
-
-            # Order by ascending stream ordering
-            rows.sort()
-
-            return rows
-
-        return self.runInteraction(
-            "get_all_new_device_messages", get_all_new_device_messages_txn
-        )
-
-    @defer.inlineCallbacks
-    def _background_drop_index_device_inbox(self, progress, batch_size):
-        def reindex_txn(conn):
-            txn = conn.cursor()
-            txn.execute("DROP INDEX IF EXISTS device_inbox_stream_id")
-            txn.close()
-
-        yield self.runWithConnection(reindex_txn)
-
-        yield self._end_background_update(self.DEVICE_INBOX_STREAM_ID)
-
-        return 1
diff --git a/synapse/storage/devices.py b/synapse/storage/data_stores/main/devices.py
index 79a58df591..fb9f798e29 100644
--- a/synapse/storage/devices.py
+++ b/synapse/storage/data_stores/main/devices.py
@@ -1,5 +1,7 @@
 # -*- coding: utf-8 -*-
 # Copyright 2016 OpenMarket Ltd
+# Copyright 2019 New Vector Ltd
+# Copyright 2019 The Matrix.org Foundation C.I.C.
 #
 # Licensed under the Apache License, Version 2.0 (the "License");
 # you may not use this file except in compliance with the License.
@@ -13,6 +15,7 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 import logging
+from typing import List, Optional, Set, Tuple
 
 from six import iteritems
 
@@ -20,7 +23,7 @@ from canonicaljson import json
 
 from twisted.internet import defer
 
-from synapse.api.errors import StoreError
+from synapse.api.errors import Codes, StoreError
 from synapse.logging.opentracing import (
     get_active_span_text_map,
     set_tag,
@@ -28,10 +31,21 @@ from synapse.logging.opentracing import (
     whitelisted_homeserver,
 )
 from synapse.metrics.background_process_metrics import run_as_background_process
-from synapse.storage._base import Cache, SQLBaseStore, db_to_json
-from synapse.storage.background_updates import BackgroundUpdateStore
-from synapse.util import batch_iter
-from synapse.util.caches.descriptors import cached, cachedInlineCallbacks, cachedList
+from synapse.storage._base import SQLBaseStore, db_to_json, make_in_list_sql_clause
+from synapse.storage.database import (
+    Database,
+    LoggingTransaction,
+    make_tuple_comparison_clause,
+)
+from synapse.types import Collection, get_verify_key_from_cross_signing_key
+from synapse.util.caches.descriptors import (
+    Cache,
+    cached,
+    cachedInlineCallbacks,
+    cachedList,
+)
+from synapse.util.iterutils import batch_iter
+from synapse.util.stringutils import shortstr
 
 logger = logging.getLogger(__name__)
 
@@ -39,10 +53,13 @@ DROP_DEVICE_LIST_STREAMS_NON_UNIQUE_INDEXES = (
     "drop_device_list_streams_non_unique_indexes"
 )
 
+BG_UPDATE_REMOVE_DUP_OUTBOUND_POKES = "remove_dup_outbound_pokes"
+
 
 class DeviceWorkerStore(SQLBaseStore):
     def get_device(self, user_id, device_id):
-        """Retrieve a device.
+        """Retrieve a device. Only returns devices that are not marked as
+        hidden.
 
         Args:
             user_id (str): The ID of the user which owns the device
@@ -52,16 +69,17 @@ class DeviceWorkerStore(SQLBaseStore):
         Raises:
             StoreError: if the device is not found
         """
-        return self._simple_select_one(
+        return self.db.simple_select_one(
             table="devices",
-            keyvalues={"user_id": user_id, "device_id": device_id},
+            keyvalues={"user_id": user_id, "device_id": device_id, "hidden": False},
             retcols=("user_id", "device_id", "display_name"),
             desc="get_device",
         )
 
     @defer.inlineCallbacks
     def get_devices_by_user(self, user_id):
-        """Retrieve all of a user's registered devices.
+        """Retrieve all of a user's registered devices. Only returns devices
+        that are not marked as hidden.
 
         Args:
             user_id (str):
@@ -70,9 +88,9 @@ class DeviceWorkerStore(SQLBaseStore):
             containing "device_id", "user_id" and "display_name" for each
             device.
         """
-        devices = yield self._simple_select_list(
+        devices = yield self.db.simple_select_list(
             table="devices",
-            keyvalues={"user_id": user_id},
+            keyvalues={"user_id": user_id, "hidden": False},
             retcols=("user_id", "device_id", "display_name"),
             desc="get_devices_by_user",
         )
@@ -81,13 +99,18 @@ class DeviceWorkerStore(SQLBaseStore):
 
     @trace
     @defer.inlineCallbacks
-    def get_devices_by_remote(self, destination, from_stream_id, limit):
-        """Get stream of updates to send to remote servers
+    def get_device_updates_by_remote(self, destination, from_stream_id, limit):
+        """Get a stream of device updates to send to the given remote server.
 
+        Args:
+            destination (str): The host the device updates are intended for
+            from_stream_id (int): The minimum stream_id to filter updates by, exclusive
+            limit (int): Maximum number of device updates to return
         Returns:
-            Deferred[tuple[int, list[dict]]]:
+            Deferred[tuple[int, list[tuple[str, dict]]]]:
                 current stream id (ie, the stream id of the last update included in the
-                response), and the list of updates
+                response), and the list of updates, where each update is a pair of EDU
+                type and EDU contents
         """
         now_stream_id = self._device_list_id_gen.get_current_token()
 
@@ -97,36 +120,49 @@ class DeviceWorkerStore(SQLBaseStore):
         if not has_changed:
             return now_stream_id, []
 
-        # We retrieve n+1 devices from the list of outbound pokes where n is
-        # our outbound device update limit. We then check if the very last
-        # device has the same stream_id as the second-to-last device. If so,
-        # then we ignore all devices with that stream_id and only send the
-        # devices with a lower stream_id.
-        #
-        # If when culling the list we end up with no devices afterwards, we
-        # consider the device update to be too large, and simply skip the
-        # stream_id; the rationale being that such a large device list update
-        # is likely an error.
-        updates = yield self.runInteraction(
-            "get_devices_by_remote",
-            self._get_devices_by_remote_txn,
+        updates = yield self.db.runInteraction(
+            "get_device_updates_by_remote",
+            self._get_device_updates_by_remote_txn,
             destination,
             from_stream_id,
             now_stream_id,
-            limit + 1,
+            limit,
         )
 
         # Return an empty list if there are no updates
         if not updates:
             return now_stream_id, []
 
-        # if we have exceeded the limit, we need to exclude any results with the
-        # same stream_id as the last row.
-        if len(updates) > limit:
-            stream_id_cutoff = updates[-1][2]
-            now_stream_id = stream_id_cutoff - 1
-        else:
-            stream_id_cutoff = None
+        # get the cross-signing keys of the users in the list, so that we can
+        # determine which of the device changes were cross-signing keys
+        users = {r[0] for r in updates}
+        master_key_by_user = {}
+        self_signing_key_by_user = {}
+        for user in users:
+            cross_signing_key = yield self.get_e2e_cross_signing_key(user, "master")
+            if cross_signing_key:
+                key_id, verify_key = get_verify_key_from_cross_signing_key(
+                    cross_signing_key
+                )
+                # verify_key is a VerifyKey from signedjson, which uses
+                # .version to denote the portion of the key ID after the
+                # algorithm and colon, which is the device ID
+                master_key_by_user[user] = {
+                    "key_info": cross_signing_key,
+                    "device_id": verify_key.version,
+                }
+
+            cross_signing_key = yield self.get_e2e_cross_signing_key(
+                user, "self_signing"
+            )
+            if cross_signing_key:
+                key_id, verify_key = get_verify_key_from_cross_signing_key(
+                    cross_signing_key
+                )
+                self_signing_key_by_user[user] = {
+                    "key_info": cross_signing_key,
+                    "device_id": verify_key.version,
+                }
 
         # Perform the equivalent of a GROUP BY
         #
@@ -135,7 +171,6 @@ class DeviceWorkerStore(SQLBaseStore):
         # the max stream_id across each set of duplicate entries
         #
         # maps (user_id, device_id) -> (stream_id, opentracing_context)
-        # as long as their stream_id does not match that of the last row
         #
         # opentracing_context contains the opentracing metadata for the request
         # that created the poke
@@ -144,39 +179,43 @@ class DeviceWorkerStore(SQLBaseStore):
         # context which created the Edu.
 
         query_map = {}
-        for update in updates:
-            if stream_id_cutoff is not None and update[2] >= stream_id_cutoff:
-                # Stop processing updates
-                break
-
-            key = (update[0], update[1])
-
-            update_context = update[3]
-            update_stream_id = update[2]
-
-            previous_update_stream_id, _ = query_map.get(key, (0, None))
-
-            if update_stream_id > previous_update_stream_id:
-                query_map[key] = (update_stream_id, update_context)
+        cross_signing_keys_by_user = {}
+        for user_id, device_id, update_stream_id, update_context in updates:
+            if (
+                user_id in master_key_by_user
+                and device_id == master_key_by_user[user_id]["device_id"]
+            ):
+                result = cross_signing_keys_by_user.setdefault(user_id, {})
+                result["master_key"] = master_key_by_user[user_id]["key_info"]
+            elif (
+                user_id in self_signing_key_by_user
+                and device_id == self_signing_key_by_user[user_id]["device_id"]
+            ):
+                result = cross_signing_keys_by_user.setdefault(user_id, {})
+                result["self_signing_key"] = self_signing_key_by_user[user_id][
+                    "key_info"
+                ]
+            else:
+                key = (user_id, device_id)
 
-        # If we didn't find any updates with a stream_id lower than the cutoff, it
-        # means that there are more than limit updates all of which have the same
-        # steam_id.
+                previous_update_stream_id, _ = query_map.get(key, (0, None))
 
-        # That should only happen if a client is spamming the server with new
-        # devices, in which case E2E isn't going to work well anyway. We'll just
-        # skip that stream_id and return an empty list, and continue with the next
-        # stream_id next time.
-        if not query_map:
-            return stream_id_cutoff, []
+                if update_stream_id > previous_update_stream_id:
+                    query_map[key] = (update_stream_id, update_context)
 
         results = yield self._get_device_update_edus_by_remote(
             destination, from_stream_id, query_map
         )
 
+        # add the updated cross-signing keys to the results list
+        for user_id, result in iteritems(cross_signing_keys_by_user):
+            result["user_id"] = user_id
+            # FIXME: switch to m.signing_key_update when MSC1756 is merged into the spec
+            results.append(("org.matrix.signing_key_update", result))
+
         return now_stream_id, results
 
-    def _get_devices_by_remote_txn(
+    def _get_device_updates_by_remote_txn(
         self, txn, destination, from_stream_id, now_stream_id, limit
     ):
         """Return device update information for a given remote destination
@@ -191,13 +230,14 @@ class DeviceWorkerStore(SQLBaseStore):
         Returns:
             List: List of device updates
         """
+        # get the list of device updates that need to be sent
         sql = """
             SELECT user_id, device_id, stream_id, opentracing_context FROM device_lists_outbound_pokes
-            WHERE destination = ? AND ? < stream_id AND stream_id <= ? AND sent = ?
+            WHERE destination = ? AND ? < stream_id AND stream_id <= ?
             ORDER BY stream_id
             LIMIT ?
         """
-        txn.execute(sql, (destination, from_stream_id, now_stream_id, False, limit))
+        txn.execute(sql, (destination, from_stream_id, now_stream_id, limit))
 
         return list(txn)
 
@@ -216,12 +256,16 @@ class DeviceWorkerStore(SQLBaseStore):
             List[Dict]: List of objects representing a device update EDU
 
         """
-        devices = yield self.runInteraction(
-            "_get_e2e_device_keys_txn",
-            self._get_e2e_device_keys_txn,
-            query_map.keys(),
-            include_all_devices=True,
-            include_deleted_devices=True,
+        devices = (
+            yield self.db.runInteraction(
+                "_get_e2e_device_keys_txn",
+                self._get_e2e_device_keys_txn,
+                query_map.keys(),
+                include_all_devices=True,
+                include_deleted_devices=True,
+            )
+            if query_map
+            else {}
         )
 
         results = []
@@ -231,7 +275,14 @@ class DeviceWorkerStore(SQLBaseStore):
             prev_id = yield self._get_last_device_update_for_remote_user(
                 destination, user_id, from_stream_id
             )
-            for device_id, device in iteritems(user_devices):
+
+            # make sure we go through the devices in stream order
+            device_ids = sorted(
+                user_devices.keys(), key=lambda i: query_map[(user_id, i)][0],
+            )
+
+            for device_id in device_ids:
+                device = user_devices[device_id]
                 stream_id, opentracing_context = query_map[(user_id, device_id)]
                 result = {
                     "user_id": user_id,
@@ -247,13 +298,20 @@ class DeviceWorkerStore(SQLBaseStore):
                     key_json = device.get("key_json", None)
                     if key_json:
                         result["keys"] = db_to_json(key_json)
+
+                        if "signatures" in device:
+                            for sig_user_id, sigs in device["signatures"].items():
+                                result["keys"].setdefault("signatures", {}).setdefault(
+                                    sig_user_id, {}
+                                ).update(sigs)
+
                     device_display_name = device.get("device_display_name", None)
                     if device_display_name:
                         result["device_display_name"] = device_display_name
                 else:
                     result["deleted"] = True
 
-                results.append(result)
+                results.append(("m.device_list_update", result))
 
         return results
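
The loop above walks the devices in stream order and folds any cross-signing signatures into the device's key JSON before emitting each m.device_list_update EDU. A tiny self-contained illustration of that nested setdefault/update merge (all values below are made up):

    keys = {"algorithms": ["m.olm.v1"], "keys": {"ed25519:DEVICEID": "devicekey"}}
    device_sigs = {"@alice:example.org": {"ed25519:DEVICEID": "signature-base64"}}

    for sig_user_id, sigs in device_sigs.items():
        keys.setdefault("signatures", {}).setdefault(sig_user_id, {}).update(sigs)

    assert keys["signatures"]["@alice:example.org"]["ed25519:DEVICEID"] == "signature-base64"
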
 
@@ -270,12 +328,12 @@ class DeviceWorkerStore(SQLBaseStore):
             rows = txn.fetchall()
             return rows[0][0]
 
-        return self.runInteraction("get_last_device_update_for_remote_user", f)
+        return self.db.runInteraction("get_last_device_update_for_remote_user", f)
 
     def mark_as_sent_devices_by_remote(self, destination, stream_id):
         """Mark that updates have successfully been sent to the destination.
         """
-        return self.runInteraction(
+        return self.db.runInteraction(
             "mark_as_sent_devices_by_remote",
             self._mark_as_sent_devices_by_remote_txn,
             destination,
@@ -284,32 +342,23 @@ class DeviceWorkerStore(SQLBaseStore):
 
     def _mark_as_sent_devices_by_remote_txn(self, txn, destination, stream_id):
         # We update the device_lists_outbound_last_success with the successfully
-        # poked users. We do the join to see which users need to be inserted and
-        # which updated.
+        # poked users.
         sql = """
-            SELECT user_id, coalesce(max(o.stream_id), 0), (max(s.stream_id) IS NOT NULL)
+            SELECT user_id, coalesce(max(o.stream_id), 0)
             FROM device_lists_outbound_pokes as o
-            LEFT JOIN device_lists_outbound_last_success as s
-                USING (destination, user_id)
             WHERE destination = ? AND o.stream_id <= ?
             GROUP BY user_id
         """
         txn.execute(sql, (destination, stream_id))
         rows = txn.fetchall()
 
-        sql = """
-            UPDATE device_lists_outbound_last_success
-            SET stream_id = ?
-            WHERE destination = ? AND user_id = ?
-        """
-        txn.executemany(sql, ((row[1], destination, row[0]) for row in rows if row[2]))
-
-        sql = """
-            INSERT INTO device_lists_outbound_last_success
-            (destination, user_id, stream_id) VALUES (?, ?, ?)
-        """
-        txn.executemany(
-            sql, ((destination, row[0], row[1]) for row in rows if not row[2])
+        self.db.simple_upsert_many_txn(
+            txn=txn,
+            table="device_lists_outbound_last_success",
+            key_names=("destination", "user_id"),
+            key_values=((destination, user_id) for user_id, _ in rows),
+            value_names=("stream_id",),
+            value_values=((stream_id,) for _, stream_id in rows),
         )
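
The hunk above folds a manual UPDATE-then-INSERT pair into a single bulk upsert helper. As a rough illustration of upsert-many semantics on an engine with native support (this is not necessarily how simple_upsert_many_txn is implemented internally, and the schema below is simplified), a plain sqlite3 version could look like:

    import sqlite3

    conn = sqlite3.connect(":memory:")  # requires SQLite >= 3.24 for ON CONFLICT
    conn.execute(
        """CREATE TABLE device_lists_outbound_last_success (
               destination TEXT, user_id TEXT, stream_id BIGINT,
               UNIQUE (destination, user_id)
           )"""
    )
    rows = [("remote.example.org", "@alice:example.org", 42)]  # hypothetical data
    conn.executemany(
        """INSERT INTO device_lists_outbound_last_success (destination, user_id, stream_id)
           VALUES (?, ?, ?)
           ON CONFLICT (destination, user_id) DO UPDATE SET stream_id = excluded.stream_id""",
        rows,
    )
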
 
         # Delete all sent outbound pokes
@@ -319,6 +368,41 @@ class DeviceWorkerStore(SQLBaseStore):
         """
         txn.execute(sql, (destination, stream_id))
 
+    @defer.inlineCallbacks
+    def add_user_signature_change_to_streams(self, from_user_id, user_ids):
+        """Persist that a user has made new signatures
+
+        Args:
+            from_user_id (str): the user who made the signatures
+            user_ids (list[str]): the users who were signed
+        """
+
+        with self._device_list_id_gen.get_next() as stream_id:
+            yield self.db.runInteraction(
+                "add_user_sig_change_to_streams",
+                self._add_user_signature_change_txn,
+                from_user_id,
+                user_ids,
+                stream_id,
+            )
+        return stream_id
+
+    def _add_user_signature_change_txn(self, txn, from_user_id, user_ids, stream_id):
+        txn.call_after(
+            self._user_signature_stream_cache.entity_has_changed,
+            from_user_id,
+            stream_id,
+        )
+        self.db.simple_insert_txn(
+            txn,
+            "user_signature_stream",
+            values={
+                "stream_id": stream_id,
+                "from_user_id": from_user_id,
+                "user_ids": json.dumps(user_ids),
+            },
+        )
+
     def get_device_stream_token(self):
         return self._device_list_id_gen.get_current_token()
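
add_user_signature_change_to_streams above allocates an id with `with self._device_list_id_gen.get_next() as stream_id:` before writing the stream row, and get_device_stream_token simply reports the latest allocated id. A deliberately simplified, in-memory sketch of that allocator pattern (not Synapse's actual StreamIdGenerator, which also tracks ids that are still being persisted):

    import threading
    from contextlib import contextmanager

    class SimpleStreamIdGen:
        """Hands out monotonically increasing stream ids; illustrative only."""

        def __init__(self, current: int = 0):
            self._lock = threading.Lock()
            self._current = current

        @contextmanager
        def get_next(self):
            with self._lock:
                self._current += 1
                allocated = self._current
            yield allocated  # the caller writes its row while holding the id

        def get_current_token(self) -> int:
            return self._current

    gen = SimpleStreamIdGen()
    with gen.get_next() as stream_id:
        pass  # persist the user_signature_stream row with this stream_id
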
 
@@ -336,11 +420,17 @@ class DeviceWorkerStore(SQLBaseStore):
             a set of user_ids and results_map is a mapping of
             user_id -> device_id -> device_info
         """
-        user_ids = set(user_id for user_id, _ in query_list)
+        user_ids = {user_id for user_id, _ in query_list}
         user_map = yield self.get_device_list_last_stream_id_for_remotes(list(user_ids))
-        user_ids_in_cache = set(
-            user_id for user_id, stream_id in user_map.items() if stream_id
+
+        # We go and check if any of the users need to have their device lists
+        # resynced. If they do then we remove them from the cached list.
+        users_needing_resync = yield self.get_user_ids_requiring_device_list_resync(
+            user_ids
         )
+        user_ids_in_cache = {
+            user_id for user_id, stream_id in user_map.items() if stream_id
+        } - users_needing_resync
         user_ids_not_in_cache = user_ids - user_ids_in_cache
 
         results = {}
@@ -352,7 +442,7 @@ class DeviceWorkerStore(SQLBaseStore):
                 device = yield self._get_cached_user_device(user_id, device_id)
                 results.setdefault(user_id, {})[device_id] = device
             else:
-                results[user_id] = yield self._get_cached_devices_for_user(user_id)
+                results[user_id] = yield self.get_cached_devices_for_user(user_id)
 
         set_tag("in_cache", results)
         set_tag("not_in_cache", user_ids_not_in_cache)
@@ -361,7 +451,7 @@ class DeviceWorkerStore(SQLBaseStore):
 
     @cachedInlineCallbacks(num_args=2, tree=True)
     def _get_cached_user_device(self, user_id, device_id):
-        content = yield self._simple_select_one_onecol(
+        content = yield self.db.simple_select_one_onecol(
             table="device_lists_remote_cache",
             keyvalues={"user_id": user_id, "device_id": device_id},
             retcol="content",
@@ -370,12 +460,12 @@ class DeviceWorkerStore(SQLBaseStore):
         return db_to_json(content)
 
     @cachedInlineCallbacks()
-    def _get_cached_devices_for_user(self, user_id):
-        devices = yield self._simple_select_list(
+    def get_cached_devices_for_user(self, user_id):
+        devices = yield self.db.simple_select_list(
             table="device_lists_remote_cache",
             keyvalues={"user_id": user_id},
             retcols=("device_id", "content"),
-            desc="_get_cached_devices_for_user",
+            desc="get_cached_devices_for_user",
         )
         return {
             device["device_id"]: db_to_json(device["content"]) for device in devices
@@ -387,7 +477,7 @@ class DeviceWorkerStore(SQLBaseStore):
         Returns:
             (stream_id, devices)
         """
-        return self.runInteraction(
+        return self.db.runInteraction(
             "get_devices_with_keys_by_user",
             self._get_devices_with_keys_by_user_txn,
             user_id,
@@ -409,6 +499,13 @@ class DeviceWorkerStore(SQLBaseStore):
                 key_json = device.get("key_json", None)
                 if key_json:
                     result["keys"] = db_to_json(key_json)
+
+                    if "signatures" in device:
+                        for sig_user_id, sigs in device["signatures"].items():
+                            result["keys"].setdefault("signatures", {}).setdefault(
+                                sig_user_id, {}
+                            ).update(sigs)
+
                 device_display_name = device.get("device_display_name", None)
                 if device_display_name:
                     result["device_display_name"] = device_display_name
@@ -435,8 +532,8 @@ class DeviceWorkerStore(SQLBaseStore):
 
         # Get set of users who *may* have changed. Users not in the returned
         # list have definitely not changed.
-        to_check = list(
-            self._device_list_stream_cache.get_entities_changed(user_ids, from_key)
+        to_check = self._device_list_stream_cache.get_entities_changed(
+            user_ids, from_key
         )
 
         if not to_check:
@@ -448,35 +545,71 @@ class DeviceWorkerStore(SQLBaseStore):
             sql = """
                 SELECT DISTINCT user_id FROM device_lists_stream
                 WHERE stream_id > ?
-                AND user_id IN (%s)
+                AND
             """
 
             for chunk in batch_iter(to_check, 100):
-                txn.execute(sql % (",".join("?" for _ in chunk),), (from_key,) + chunk)
+                clause, args = make_in_list_sql_clause(
+                    txn.database_engine, "user_id", chunk
+                )
+                txn.execute(sql + clause, (from_key,) + tuple(args))
                 changes.update(user_id for user_id, in txn)
 
             return changes
 
-        return self.runInteraction(
+        return self.db.runInteraction(
             "get_users_whose_devices_changed", _get_users_whose_devices_changed_txn
         )
 
-    def get_all_device_list_changes_for_remotes(self, from_key, to_key):
-        """Return a list of `(stream_id, user_id, destination)` which is the
-        combined list of changes to devices, and which destinations need to be
-        poked. `destination` may be None if no destinations need to be poked.
+    @defer.inlineCallbacks
+    def get_users_whose_signatures_changed(self, user_id, from_key):
+        """Get the users who have new cross-signing signatures made by `user_id` since
+        `from_key`.
+
+        Args:
+            user_id (str): the user who made the signatures
+            from_key (str): The device lists stream token
         """
-        # We do a group by here as there can be a large number of duplicate
-        # entries, since we throw away device IDs.
+        from_key = int(from_key)
+        if self._user_signature_stream_cache.has_entity_changed(user_id, from_key):
+            sql = """
+                SELECT DISTINCT user_ids FROM user_signature_stream
+                WHERE from_user_id = ? AND stream_id > ?
+            """
+            rows = yield self.db.execute(
+                "get_users_whose_signatures_changed", None, sql, user_id, from_key
+            )
+            return {user for row in rows for user in json.loads(row[0])}
+        else:
+            return set()
+
+    async def get_all_device_list_changes_for_remotes(
+        self, from_key: int, to_key: int, limit: int,
+    ) -> List[Tuple[int, str]]:
+        """Return a list of `(stream_id, entity)` which is the combined list of
+        changes to devices and which destinations need to be poked. Entity is
+        either a user ID (starting with '@') or a remote destination.
+        """
+
+        # This query Does The Right Thing where it'll correctly apply the
+        # bounds to the inner queries.
         sql = """
-            SELECT MAX(stream_id) AS stream_id, user_id, destination
-            FROM device_lists_stream
-            LEFT JOIN device_lists_outbound_pokes USING (stream_id, user_id, device_id)
+            SELECT stream_id, entity FROM (
+                SELECT stream_id, user_id AS entity FROM device_lists_stream
+                UNION ALL
+                SELECT stream_id, destination AS entity FROM device_lists_outbound_pokes
+            ) AS e
             WHERE ? < stream_id AND stream_id <= ?
-            GROUP BY user_id, destination
+            LIMIT ?
         """
-        return self._execute(
-            "get_all_device_list_changes_for_remotes", None, sql, from_key, to_key
+
+        return await self.db.execute(
+            "get_all_device_list_changes_for_remotes",
+            None,
+            sql,
+            from_key,
+            to_key,
+            limit,
         )
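
get_users_whose_devices_changed above splits the candidate user ids into chunks of 100 with batch_iter so that each IN-list query stays a manageable size. The real batch_iter lives in synapse.util.iterutils; the standalone sketch below is only illustrative:

    from itertools import islice
    from typing import Iterable, Iterator, Tuple, TypeVar

    T = TypeVar("T")

    def batch_iter_sketch(iterable: Iterable[T], size: int) -> Iterator[Tuple[T, ...]]:
        # Yield successive tuples of at most `size` items from `iterable`.
        it = iter(iterable)
        while True:
            chunk = tuple(islice(it, size))
            if not chunk:
                return
            yield chunk

    for chunk in batch_iter_sketch(["@a:example.org", "@b:example.org", "@c:example.org"], 2):
        pass  # run one parameterised SELECT per chunk
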
 
     @cached(max_entries=10000)
@@ -484,7 +617,7 @@ class DeviceWorkerStore(SQLBaseStore):
         """Get the last stream_id we got for a user. May be None if we haven't
         got any information for them.
         """
-        return self._simple_select_one_onecol(
+        return self.db.simple_select_one_onecol(
             table="device_lists_remote_extremeties",
             keyvalues={"user_id": user_id},
             retcol="stream_id",
@@ -498,7 +631,7 @@ class DeviceWorkerStore(SQLBaseStore):
         inlineCallbacks=True,
     )
     def get_device_list_last_stream_id_for_remotes(self, user_ids):
-        rows = yield self._simple_select_many_batch(
+        rows = yield self.db.simple_select_many_batch(
             table="device_lists_remote_extremeties",
             column="user_id",
             iterable=user_ids,
@@ -511,20 +644,72 @@ class DeviceWorkerStore(SQLBaseStore):
 
         return results
 
+    @defer.inlineCallbacks
+    def get_user_ids_requiring_device_list_resync(
+        self, user_ids: Optional[Collection[str]] = None,
+    ) -> Set[str]:
+        """Given a list of remote users return the list of users that we
+        should resync the device lists for. If None is given instead of a list,
+        return every user that we should resync the device lists for.
 
-class DeviceStore(DeviceWorkerStore, BackgroundUpdateStore):
-    def __init__(self, db_conn, hs):
-        super(DeviceStore, self).__init__(db_conn, hs)
+        Returns:
+            The IDs of users whose device lists need resync.
+        """
+        if user_ids:
+            rows = yield self.db.simple_select_many_batch(
+                table="device_lists_remote_resync",
+                column="user_id",
+                iterable=user_ids,
+                retcols=("user_id",),
+                desc="get_user_ids_requiring_device_list_resync_with_iterable",
+            )
+        else:
+            rows = yield self.db.simple_select_list(
+                table="device_lists_remote_resync",
+                keyvalues=None,
+                retcols=("user_id",),
+                desc="get_user_ids_requiring_device_list_resync",
+            )
 
-        # Map of (user_id, device_id) -> bool. If there is an entry that implies
-        # the device exists.
-        self.device_id_exists_cache = Cache(
-            name="device_id_exists", keylen=2, max_entries=10000
+        return {row["user_id"] for row in rows}
+
+    def mark_remote_user_device_cache_as_stale(self, user_id: str):
+        """Records that the server has reason to believe the cache of the devices
+        for the remote users is out of date.
+        """
+        return self.db.simple_upsert(
+            table="device_lists_remote_resync",
+            keyvalues={"user_id": user_id},
+            values={},
+            insertion_values={"added_ts": self._clock.time_msec()},
+            desc="make_remote_user_device_cache_as_stale",
         )
 
-        self._clock.looping_call(self._prune_old_outbound_device_pokes, 60 * 60 * 1000)
+    def mark_remote_user_device_list_as_unsubscribed(self, user_id):
+        """Mark that we no longer track device lists for remote user.
+        """
 
-        self.register_background_index_update(
+        def _mark_remote_user_device_list_as_unsubscribed_txn(txn):
+            self.db.simple_delete_txn(
+                txn,
+                table="device_lists_remote_extremeties",
+                keyvalues={"user_id": user_id},
+            )
+            self._invalidate_cache_and_stream(
+                txn, self.get_device_list_last_stream_id_for_remote, (user_id,)
+            )
+
+        return self.db.runInteraction(
+            "mark_remote_user_device_list_as_unsubscribed",
+            _mark_remote_user_device_list_as_unsubscribed_txn,
+        )
+
+
+class DeviceBackgroundUpdateStore(SQLBaseStore):
+    def __init__(self, database: Database, db_conn, hs):
+        super(DeviceBackgroundUpdateStore, self).__init__(database, db_conn, hs)
+
+        self.db.updates.register_background_index_update(
             "device_lists_stream_idx",
             index_name="device_lists_stream_user_id",
             table="device_lists_stream",
@@ -532,7 +717,7 @@ class DeviceStore(DeviceWorkerStore, BackgroundUpdateStore):
         )
 
         # create a unique index on device_lists_remote_cache
-        self.register_background_index_update(
+        self.db.updates.register_background_index_update(
             "device_lists_remote_cache_unique_idx",
             index_name="device_lists_remote_cache_unique_id",
             table="device_lists_remote_cache",
@@ -541,7 +726,7 @@ class DeviceStore(DeviceWorkerStore, BackgroundUpdateStore):
         )
 
         # And one on device_lists_remote_extremeties
-        self.register_background_index_update(
+        self.db.updates.register_background_index_update(
             "device_lists_remote_extremeties_unique_idx",
             index_name="device_lists_remote_extremeties_unique_idx",
             table="device_lists_remote_extremeties",
@@ -550,11 +735,112 @@ class DeviceStore(DeviceWorkerStore, BackgroundUpdateStore):
         )
 
         # once they complete, we can remove the old non-unique indexes.
-        self.register_background_update_handler(
+        self.db.updates.register_background_update_handler(
             DROP_DEVICE_LIST_STREAMS_NON_UNIQUE_INDEXES,
             self._drop_device_list_streams_non_unique_indexes,
         )
 
+        # clear out duplicate device list outbound pokes
+        self.db.updates.register_background_update_handler(
+            BG_UPDATE_REMOVE_DUP_OUTBOUND_POKES, self._remove_duplicate_outbound_pokes,
+        )
+
+        # a pair of background updates that were added during the 1.14 release cycle,
+        # but replaced with 58/06dlols_unique_idx.py
+        self.db.updates.register_noop_background_update(
+            "device_lists_outbound_last_success_unique_idx",
+        )
+        self.db.updates.register_noop_background_update(
+            "drop_device_lists_outbound_last_success_non_unique_idx",
+        )
+
+    @defer.inlineCallbacks
+    def _drop_device_list_streams_non_unique_indexes(self, progress, batch_size):
+        def f(conn):
+            txn = conn.cursor()
+            txn.execute("DROP INDEX IF EXISTS device_lists_remote_cache_id")
+            txn.execute("DROP INDEX IF EXISTS device_lists_remote_extremeties_id")
+            txn.close()
+
+        yield self.db.runWithConnection(f)
+        yield self.db.updates._end_background_update(
+            DROP_DEVICE_LIST_STREAMS_NON_UNIQUE_INDEXES
+        )
+        return 1
+
+    async def _remove_duplicate_outbound_pokes(self, progress, batch_size):
+        # for some reason, we have accumulated duplicate entries in
+        # device_lists_outbound_pokes, which makes prune_outbound_device_list_pokes less
+        # efficient.
+        #
+        # For each duplicate, we delete all the existing rows and put one back.
+
+        KEY_COLS = ["stream_id", "destination", "user_id", "device_id"]
+        last_row = progress.get(
+            "last_row",
+            {"stream_id": 0, "destination": "", "user_id": "", "device_id": ""},
+        )
+
+        def _txn(txn):
+            clause, args = make_tuple_comparison_clause(
+                self.db.engine, [(x, last_row[x]) for x in KEY_COLS]
+            )
+            sql = """
+                SELECT stream_id, destination, user_id, device_id, MAX(ts) AS ts
+                FROM device_lists_outbound_pokes
+                WHERE %s
+                GROUP BY %s
+                HAVING count(*) > 1
+                ORDER BY %s
+                LIMIT ?
+                """ % (
+                clause,  # WHERE
+                ",".join(KEY_COLS),  # GROUP BY
+                ",".join(KEY_COLS),  # ORDER BY
+            )
+            txn.execute(sql, args + [batch_size])
+            rows = self.db.cursor_to_dict(txn)
+
+            row = None
+            for row in rows:
+                self.db.simple_delete_txn(
+                    txn, "device_lists_outbound_pokes", {x: row[x] for x in KEY_COLS},
+                )
+
+                row["sent"] = False
+                self.db.simple_insert_txn(
+                    txn, "device_lists_outbound_pokes", row,
+                )
+
+            if row:
+                self.db.updates._background_update_progress_txn(
+                    txn, BG_UPDATE_REMOVE_DUP_OUTBOUND_POKES, {"last_row": row},
+                )
+
+            return len(rows)
+
+        rows = await self.db.runInteraction(BG_UPDATE_REMOVE_DUP_OUTBOUND_POKES, _txn)
+
+        if not rows:
+            await self.db.updates._end_background_update(
+                BG_UPDATE_REMOVE_DUP_OUTBOUND_POKES
+            )
+
+        return rows
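
_remove_duplicate_outbound_pokes above pages through the table using make_tuple_comparison_clause, i.e. a keyset-pagination bound of the form (c1, c2, ...) > (v1, v2, ...). The sketch below shows the expanded form such a helper can generate for engines without row-value comparisons; the function name and sample values are illustrative only:

    from typing import Any, List, Sequence, Tuple

    def tuple_comparison_sketch(bounds: Sequence[Tuple[str, Any]]) -> Tuple[str, List[Any]]:
        # Expand (a, b, c) > (?, ?, ?) into
        # (a > ?) OR (a = ? AND b > ?) OR (a = ? AND b = ? AND c > ?)
        clauses = []
        args: List[Any] = []
        for i, (name, _) in enumerate(bounds):
            terms = ["%s = ?" % (n,) for n, _ in bounds[:i]] + ["%s > ?" % (name,)]
            clauses.append("(%s)" % " AND ".join(terms))
            args.extend(value for _, value in bounds[: i + 1])
        return " OR ".join(clauses), args

    clause, args = tuple_comparison_sketch(
        [("stream_id", 0), ("destination", ""), ("user_id", ""), ("device_id", "")]
    )
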
+
+
+class DeviceStore(DeviceWorkerStore, DeviceBackgroundUpdateStore):
+    def __init__(self, database: Database, db_conn, hs):
+        super(DeviceStore, self).__init__(database, db_conn, hs)
+
+        # Map of (user_id, device_id) -> bool. If there is an entry that implies
+        # the device exists.
+        self.device_id_exists_cache = Cache(
+            name="device_id_exists", keylen=2, max_entries=10000
+        )
+
+        self._clock.looping_call(self._prune_old_outbound_device_pokes, 60 * 60 * 1000)
+
     @defer.inlineCallbacks
     def store_device(self, user_id, device_id, initial_device_display_name):
         """Ensure the given device is known; add it to the store if not
@@ -567,24 +853,39 @@ class DeviceStore(DeviceWorkerStore, BackgroundUpdateStore):
         Returns:
             defer.Deferred: boolean whether the device was inserted or an
                 existing device existed with that ID.
+        Raises:
+            StoreError: if the device is already in use
         """
         key = (user_id, device_id)
         if self.device_id_exists_cache.get(key, None):
             return False
 
         try:
-            inserted = yield self._simple_insert(
+            inserted = yield self.db.simple_insert(
                 "devices",
                 values={
                     "user_id": user_id,
                     "device_id": device_id,
                     "display_name": initial_device_display_name,
+                    "hidden": False,
                 },
                 desc="store_device",
                 or_ignore=True,
             )
+            if not inserted:
+                # if the device already exists, check if it's a real device, or
+                # if the device ID is reserved by something else
+                hidden = yield self.db.simple_select_one_onecol(
+                    "devices",
+                    keyvalues={"user_id": user_id, "device_id": device_id},
+                    retcol="hidden",
+                )
+                if hidden:
+                    raise StoreError(400, "The device ID is in use", Codes.FORBIDDEN)
             self.device_id_exists_cache.prefill(key, True)
             return inserted
+        except StoreError:
+            raise
         except Exception as e:
             logger.error(
                 "store_device with device_id=%s(%r) user_id=%s(%r)"
@@ -609,9 +910,9 @@ class DeviceStore(DeviceWorkerStore, BackgroundUpdateStore):
         Returns:
             defer.Deferred
         """
-        yield self._simple_delete_one(
+        yield self.db.simple_delete_one(
             table="devices",
-            keyvalues={"user_id": user_id, "device_id": device_id},
+            keyvalues={"user_id": user_id, "device_id": device_id, "hidden": False},
             desc="delete_device",
         )
 
@@ -627,18 +928,19 @@ class DeviceStore(DeviceWorkerStore, BackgroundUpdateStore):
         Returns:
             defer.Deferred
         """
-        yield self._simple_delete_many(
+        yield self.db.simple_delete_many(
             table="devices",
             column="device_id",
             iterable=device_ids,
-            keyvalues={"user_id": user_id},
+            keyvalues={"user_id": user_id, "hidden": False},
             desc="delete_devices",
         )
         for device_id in device_ids:
             self.device_id_exists_cache.invalidate((user_id, device_id))
 
     def update_device(self, user_id, device_id, new_display_name=None):
-        """Update a device.
+        """Update a device. Only updates the device if it is not marked as
+        hidden.
 
         Args:
             user_id (str): The ID of the user which owns the device
@@ -655,24 +957,13 @@ class DeviceStore(DeviceWorkerStore, BackgroundUpdateStore):
             updates["display_name"] = new_display_name
         if not updates:
             return defer.succeed(None)
-        return self._simple_update_one(
+        return self.db.simple_update_one(
             table="devices",
-            keyvalues={"user_id": user_id, "device_id": device_id},
+            keyvalues={"user_id": user_id, "device_id": device_id, "hidden": False},
             updatevalues=updates,
             desc="update_device",
         )
 
-    @defer.inlineCallbacks
-    def mark_remote_user_device_list_as_unsubscribed(self, user_id):
-        """Mark that we no longer track device lists for remote user.
-        """
-        yield self._simple_delete(
-            table="device_lists_remote_extremeties",
-            keyvalues={"user_id": user_id},
-            desc="mark_remote_user_device_list_as_unsubscribed",
-        )
-        self.get_device_list_last_stream_id_for_remote.invalidate((user_id,))
-
     def update_remote_device_list_cache_entry(
         self, user_id, device_id, content, stream_id
     ):
@@ -690,7 +981,7 @@ class DeviceStore(DeviceWorkerStore, BackgroundUpdateStore):
         Returns:
             Deferred[None]
         """
-        return self.runInteraction(
+        return self.db.runInteraction(
             "update_remote_device_list_cache_entry",
             self._update_remote_device_list_cache_entry_txn,
             user_id,
@@ -703,7 +994,7 @@ class DeviceStore(DeviceWorkerStore, BackgroundUpdateStore):
         self, txn, user_id, device_id, content, stream_id
     ):
         if content.get("deleted"):
-            self._simple_delete_txn(
+            self.db.simple_delete_txn(
                 txn,
                 table="device_lists_remote_cache",
                 keyvalues={"user_id": user_id, "device_id": device_id},
@@ -711,7 +1002,7 @@ class DeviceStore(DeviceWorkerStore, BackgroundUpdateStore):
 
             txn.call_after(self.device_id_exists_cache.invalidate, (user_id, device_id))
         else:
-            self._simple_upsert_txn(
+            self.db.simple_upsert_txn(
                 txn,
                 table="device_lists_remote_cache",
                 keyvalues={"user_id": user_id, "device_id": device_id},
@@ -722,12 +1013,12 @@ class DeviceStore(DeviceWorkerStore, BackgroundUpdateStore):
             )
 
         txn.call_after(self._get_cached_user_device.invalidate, (user_id, device_id))
-        txn.call_after(self._get_cached_devices_for_user.invalidate, (user_id,))
+        txn.call_after(self.get_cached_devices_for_user.invalidate, (user_id,))
         txn.call_after(
             self.get_device_list_last_stream_id_for_remote.invalidate, (user_id,)
         )
 
-        self._simple_upsert_txn(
+        self.db.simple_upsert_txn(
             txn,
             table="device_lists_remote_extremeties",
             keyvalues={"user_id": user_id},
@@ -751,7 +1042,7 @@ class DeviceStore(DeviceWorkerStore, BackgroundUpdateStore):
         Returns:
             Deferred[None]
         """
-        return self.runInteraction(
+        return self.db.runInteraction(
             "update_remote_device_list_cache",
             self._update_remote_device_list_cache_txn,
             user_id,
@@ -760,11 +1051,11 @@ class DeviceStore(DeviceWorkerStore, BackgroundUpdateStore):
         )
 
     def _update_remote_device_list_cache_txn(self, txn, user_id, devices, stream_id):
-        self._simple_delete_txn(
+        self.db.simple_delete_txn(
             txn, table="device_lists_remote_cache", keyvalues={"user_id": user_id}
         )
 
-        self._simple_insert_many_txn(
+        self.db.simple_insert_many_txn(
             txn,
             table="device_lists_remote_cache",
             values=[
@@ -777,13 +1068,13 @@ class DeviceStore(DeviceWorkerStore, BackgroundUpdateStore):
             ],
         )
 
-        txn.call_after(self._get_cached_devices_for_user.invalidate, (user_id,))
+        txn.call_after(self.get_cached_devices_for_user.invalidate, (user_id,))
         txn.call_after(self._get_cached_user_device.invalidate_many, (user_id,))
         txn.call_after(
             self.get_device_list_last_stream_id_for_remote.invalidate, (user_id,)
         )
 
-        self._simple_upsert_txn(
+        self.db.simple_upsert_txn(
             txn,
             table="device_lists_remote_extremeties",
             keyvalues={"user_id": user_id},
@@ -793,34 +1084,61 @@ class DeviceStore(DeviceWorkerStore, BackgroundUpdateStore):
             lock=False,
         )
 
+        # If we're replacing the remote user's device list cache presumably
+        # we've done a full resync, so we remove the entry that says we need
+        # to resync
+        self.db.simple_delete_txn(
+            txn, table="device_lists_remote_resync", keyvalues={"user_id": user_id},
+        )
+
     @defer.inlineCallbacks
     def add_device_change_to_streams(self, user_id, device_ids, hosts):
         """Persist that a user's devices have been updated, and which hosts
         (if any) should be poked.
         """
-        with self._device_list_id_gen.get_next() as stream_id:
-            yield self.runInteraction(
-                "add_device_change_to_streams",
-                self._add_device_change_txn,
+        if not device_ids:
+            return
+
+        with self._device_list_id_gen.get_next_mult(len(device_ids)) as stream_ids:
+            yield self.db.runInteraction(
+                "add_device_change_to_stream",
+                self._add_device_change_to_stream_txn,
+                user_id,
+                device_ids,
+                stream_ids,
+            )
+
+        if not hosts:
+            return stream_ids[-1]
+
+        context = get_active_span_text_map()
+        with self._device_list_id_gen.get_next_mult(
+            len(hosts) * len(device_ids)
+        ) as stream_ids:
+            yield self.db.runInteraction(
+                "add_device_outbound_poke_to_stream",
+                self._add_device_outbound_poke_to_stream_txn,
                 user_id,
                 device_ids,
                 hosts,
-                stream_id,
+                stream_ids,
+                context,
             )
-        return stream_id
 
-    def _add_device_change_txn(self, txn, user_id, device_ids, hosts, stream_id):
-        now = self._clock.time_msec()
+        return stream_ids[-1]
 
+    def _add_device_change_to_stream_txn(
+        self,
+        txn: LoggingTransaction,
+        user_id: str,
+        device_ids: Collection[str],
+        stream_ids: List[str],
+    ):
         txn.call_after(
-            self._device_list_stream_cache.entity_has_changed, user_id, stream_id
+            self._device_list_stream_cache.entity_has_changed, user_id, stream_ids[-1],
         )
-        for host in hosts:
-            txn.call_after(
-                self._device_list_federation_stream_cache.entity_has_changed,
-                host,
-                stream_id,
-            )
+
+        min_stream_id = stream_ids[0]
 
         # Delete older entries in the table, as we really only care about
         # when the latest change happened.
@@ -829,27 +1147,38 @@ class DeviceStore(DeviceWorkerStore, BackgroundUpdateStore):
             DELETE FROM device_lists_stream
             WHERE user_id = ? AND device_id = ? AND stream_id < ?
             """,
-            [(user_id, device_id, stream_id) for device_id in device_ids],
+            [(user_id, device_id, min_stream_id) for device_id in device_ids],
         )
 
-        self._simple_insert_many_txn(
+        self.db.simple_insert_many_txn(
             txn,
             table="device_lists_stream",
             values=[
                 {"stream_id": stream_id, "user_id": user_id, "device_id": device_id}
-                for device_id in device_ids
+                for stream_id, device_id in zip(stream_ids, device_ids)
             ],
         )
 
-        context = get_active_span_text_map()
+    def _add_device_outbound_poke_to_stream_txn(
+        self, txn, user_id, device_ids, hosts, stream_ids, context,
+    ):
+        for host in hosts:
+            txn.call_after(
+                self._device_list_federation_stream_cache.entity_has_changed,
+                host,
+                stream_ids[-1],
+            )
 
-        self._simple_insert_many_txn(
+        now = self._clock.time_msec()
+        next_stream_id = iter(stream_ids)
+
+        self.db.simple_insert_many_txn(
             txn,
             table="device_lists_outbound_pokes",
             values=[
                 {
                     "destination": destination,
-                    "stream_id": stream_id,
+                    "stream_id": next(next_stream_id),
                     "user_id": user_id,
                     "device_id": device_id,
                     "sent": False,
@@ -863,18 +1192,47 @@ class DeviceStore(DeviceWorkerStore, BackgroundUpdateStore):
             ],
         )
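
The outbound-poke insert above reserves len(hosts) * len(device_ids) stream ids and draws one per (destination, device) row via next(next_stream_id), so every row gets a distinct id. A tiny illustration of that pairing (hosts and device ids below are made up):

    hosts = ["remote1.example.org", "remote2.example.org"]   # hypothetical
    device_ids = ["DEVICEAAA", "DEVICEBBB"]                  # hypothetical
    stream_ids = list(range(100, 100 + len(hosts) * len(device_ids)))

    next_stream_id = iter(stream_ids)
    rows = [
        {"destination": destination, "device_id": device_id, "stream_id": next(next_stream_id)}
        for destination in hosts
        for device_id in device_ids
    ]
    # rows[-1]["stream_id"] == stream_ids[-1], which is what the caller returns.
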
 
-    def _prune_old_outbound_device_pokes(self):
+    def _prune_old_outbound_device_pokes(self, prune_age=24 * 60 * 60 * 1000):
         """Delete old entries out of the device_lists_outbound_pokes to ensure
-        that we don't fill up due to dead servers. We keep one entry per
-        (destination, user_id) tuple to ensure that the prev_ids remain correct
-        if the server does come back.
+        that we don't fill up due to dead servers.
+
+        Normally, we try to send device updates as a delta since a previous known point:
+        this is done by setting the prev_id in the m.device_list_update EDU. However,
+        for that to work, we have to have a complete record of each change to
+        each device, which can add up to quite a lot of data.
+
+        An alternative mechanism is that, if the remote server sees that it has missed
+        an entry in the stream_id sequence for a given user, it will request a full
+        list of that user's devices. Hence, we can reduce the amount of data we have to
+        store (and transmit in some future transaction), by clearing almost everything
+        for a given destination out of the database, and having the remote server
+        resync.
+
+        All we need to do is make sure we keep at least one row for each
+        (user, destination) pair, to remind us to send an m.device_list_update EDU for
+        that user when the destination comes back. It doesn't matter which device
+        we keep.
         """
-        yesterday = self._clock.time_msec() - 24 * 60 * 60 * 1000
+        yesterday = self._clock.time_msec() - prune_age
 
         def _prune_txn(txn):
+            # look for (user, destination) pairs which have an update older than
+            # the cutoff.
+            #
+            # For each pair, we also need to know the most recent stream_id, and
+            # an arbitrary device_id at that stream_id.
             select_sql = """
-                SELECT destination, user_id, max(stream_id) as stream_id
-                FROM device_lists_outbound_pokes
+            SELECT
+                dlop1.destination,
+                dlop1.user_id,
+                MAX(dlop1.stream_id) AS stream_id,
+                (SELECT MIN(dlop2.device_id) AS device_id FROM
+                    device_lists_outbound_pokes dlop2
+                    WHERE dlop2.destination = dlop1.destination AND
+                      dlop2.user_id=dlop1.user_id AND
+                      dlop2.stream_id=MAX(dlop1.stream_id)
+                )
+            FROM device_lists_outbound_pokes dlop1
                 GROUP BY destination, user_id
                 HAVING min(ts) < ? AND count(*) > 1
             """
@@ -885,14 +1243,29 @@ class DeviceStore(DeviceWorkerStore, BackgroundUpdateStore):
             if not rows:
                 return
 
+            logger.info(
+                "Pruning old outbound device list updates for %i users/destinations: %s",
+                len(rows),
+                shortstr((row[0], row[1]) for row in rows),
+            )
+
+            # we want to keep the update with the highest stream_id for each user.
+            #
+            # there might be more than one update (with different device_ids) with the
+            # same stream_id, so we also delete all but one rows with the max stream id.
             delete_sql = """
                 DELETE FROM device_lists_outbound_pokes
-                WHERE ts < ? AND destination = ? AND user_id = ? AND stream_id < ?
+                WHERE destination = ? AND user_id = ? AND (
+                    stream_id < ? OR
+                    (stream_id = ? AND device_id != ?)
+                )
             """
-
-            txn.executemany(
-                delete_sql, ((yesterday, row[0], row[1], row[2]) for row in rows)
-            )
+            count = 0
+            for (destination, user_id, stream_id, device_id) in rows:
+                txn.execute(
+                    delete_sql, (destination, user_id, stream_id, stream_id, device_id)
+                )
+                count += txn.rowcount
 
             # Since we've deleted unsent deltas, we need to remove the entry
             # of last successful sent so that the prev_ids are correctly set.
@@ -902,23 +1275,11 @@ class DeviceStore(DeviceWorkerStore, BackgroundUpdateStore):
             """
             txn.executemany(sql, ((row[0], row[1]) for row in rows))
 
-            logger.info("Pruned %d device list outbound pokes", txn.rowcount)
+            logger.info("Pruned %d device list outbound pokes", count)
 
         return run_as_background_process(
             "prune_old_outbound_device_pokes",
-            self.runInteraction,
+            self.db.runInteraction,
             "_prune_old_outbound_device_pokes",
             _prune_txn,
         )
-
-    @defer.inlineCallbacks
-    def _drop_device_list_streams_non_unique_indexes(self, progress, batch_size):
-        def f(conn):
-            txn = conn.cursor()
-            txn.execute("DROP INDEX IF EXISTS device_lists_remote_cache_id")
-            txn.execute("DROP INDEX IF EXISTS device_lists_remote_extremeties_id")
-            txn.close()
-
-        yield self.runWithConnection(f)
-        yield self._end_background_update(DROP_DEVICE_LIST_STREAMS_NON_UNIQUE_INDEXES)
-        return 1
diff --git a/synapse/storage/directory.py b/synapse/storage/data_stores/main/directory.py
index eed7757ed5..e1d1bc3e05 100644
--- a/synapse/storage/directory.py
+++ b/synapse/storage/data_stores/main/directory.py
@@ -14,14 +14,14 @@
 # limitations under the License.
 
 from collections import namedtuple
+from typing import Optional
 
 from twisted.internet import defer
 
 from synapse.api.errors import SynapseError
+from synapse.storage._base import SQLBaseStore
 from synapse.util.caches.descriptors import cached
 
-from ._base import SQLBaseStore
-
 RoomAliasMapping = namedtuple("RoomAliasMapping", ("room_id", "room_alias", "servers"))
 
 
@@ -37,7 +37,7 @@ class DirectoryWorkerStore(SQLBaseStore):
             Deferred: results in namedtuple with keys "room_id" and
             "servers" or None if no association can be found
         """
-        room_id = yield self._simple_select_one_onecol(
+        room_id = yield self.db.simple_select_one_onecol(
             "room_aliases",
             {"room_alias": room_alias.to_string()},
             "room_id",
@@ -48,7 +48,7 @@ class DirectoryWorkerStore(SQLBaseStore):
         if not room_id:
             return None
 
-        servers = yield self._simple_select_onecol(
+        servers = yield self.db.simple_select_onecol(
             "room_alias_servers",
             {"room_alias": room_alias.to_string()},
             "server",
@@ -61,7 +61,7 @@ class DirectoryWorkerStore(SQLBaseStore):
         return RoomAliasMapping(room_id, room_alias.to_string(), servers)
 
     def get_room_alias_creator(self, room_alias):
-        return self._simple_select_one_onecol(
+        return self.db.simple_select_one_onecol(
             table="room_aliases",
             keyvalues={"room_alias": room_alias},
             retcol="creator",
@@ -70,7 +70,7 @@ class DirectoryWorkerStore(SQLBaseStore):
 
     @cached(max_entries=5000)
     def get_aliases_for_room(self, room_id):
-        return self._simple_select_onecol(
+        return self.db.simple_select_onecol(
             "room_aliases",
             {"room_id": room_id},
             "room_alias",
@@ -94,7 +94,7 @@ class DirectoryStore(DirectoryWorkerStore):
         """
 
         def alias_txn(txn):
-            self._simple_insert_txn(
+            self.db.simple_insert_txn(
                 txn,
                 "room_aliases",
                 {
@@ -104,7 +104,7 @@ class DirectoryStore(DirectoryWorkerStore):
                 },
             )
 
-            self._simple_insert_many_txn(
+            self.db.simple_insert_many_txn(
                 txn,
                 table="room_alias_servers",
                 values=[
@@ -118,7 +118,9 @@ class DirectoryStore(DirectoryWorkerStore):
             )
 
         try:
-            ret = yield self.runInteraction("create_room_alias_association", alias_txn)
+            ret = yield self.db.runInteraction(
+                "create_room_alias_association", alias_txn
+            )
         except self.database_engine.module.IntegrityError:
             raise SynapseError(
                 409, "Room alias %s already exists" % room_alias.to_string()
@@ -127,7 +129,7 @@ class DirectoryStore(DirectoryWorkerStore):
 
     @defer.inlineCallbacks
     def delete_room_alias(self, room_alias):
-        room_id = yield self.runInteraction(
+        room_id = yield self.db.runInteraction(
             "delete_room_alias", self._delete_room_alias_txn, room_alias
         )
 
@@ -158,10 +160,29 @@ class DirectoryStore(DirectoryWorkerStore):
 
         return room_id
 
-    def update_aliases_for_room(self, old_room_id, new_room_id, creator):
+    def update_aliases_for_room(
+        self, old_room_id: str, new_room_id: str, creator: Optional[str] = None,
+    ):
+        """Repoint all of the aliases for a given room, to a different room.
+
+        Args:
+            old_room_id:
+            new_room_id:
+            creator: The user to record as the creator of the new mapping.
+                If None, the creator will be left unchanged.
+        """
+
         def _update_aliases_for_room_txn(txn):
-            sql = "UPDATE room_aliases SET room_id = ?, creator = ? WHERE room_id = ?"
-            txn.execute(sql, (new_room_id, creator, old_room_id))
+            update_creator_sql = ""
+            sql_params = (new_room_id, old_room_id)
+            if creator:
+                update_creator_sql = ", creator = ?"
+                sql_params = (new_room_id, creator, old_room_id)
+
+            sql = "UPDATE room_aliases SET room_id = ? %s WHERE room_id = ?" % (
+                update_creator_sql,
+            )
+            txn.execute(sql, sql_params)
             self._invalidate_cache_and_stream(
                 txn, self.get_aliases_for_room, (old_room_id,)
             )
@@ -169,6 +190,6 @@ class DirectoryStore(DirectoryWorkerStore):
                 txn, self.get_aliases_for_room, (new_room_id,)
             )
 
-        return self.runInteraction(
+        return self.db.runInteraction(
             "_update_aliases_for_room_txn", _update_aliases_for_room_txn
         )
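
The `update_aliases_for_room` change above only touches the `creator` column when a creator is supplied; otherwise the stored value is preserved. As a quick illustration of the conditional SQL it builds, here is a standalone sketch (the helper name and example IDs are illustrative, not part of Synapse):

from typing import Optional, Tuple


def build_room_alias_update(
    new_room_id: str, old_room_id: str, creator: Optional[str] = None
) -> Tuple[str, tuple]:
    """Build (sql, params) repointing room_aliases, updating creator only if given."""
    update_creator_sql = ""
    params = (new_room_id, old_room_id)
    if creator:
        update_creator_sql = ", creator = ?"
        params = (new_room_id, creator, old_room_id)
    sql = "UPDATE room_aliases SET room_id = ? %s WHERE room_id = ?" % (update_creator_sql,)
    return sql, params


# build_room_alias_update("!new:hs", "!old:hs")
#   -> ("UPDATE room_aliases SET room_id = ?  WHERE room_id = ?",
#       ("!new:hs", "!old:hs"))
# build_room_alias_update("!new:hs", "!old:hs", "@admin:hs")
#   -> ("UPDATE room_aliases SET room_id = ? , creator = ? WHERE room_id = ?",
#       ("!new:hs", "@admin:hs", "!old:hs"))
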
diff --git a/synapse/storage/e2e_room_keys.py b/synapse/storage/data_stores/main/e2e_room_keys.py
index be2fe2bab6..23f4570c4b 100644
--- a/synapse/storage/e2e_room_keys.py
+++ b/synapse/storage/data_stores/main/e2e_room_keys.py
@@ -1,5 +1,6 @@
 # -*- coding: utf-8 -*-
 # Copyright 2017 New Vector Ltd
+# Copyright 2019 Matrix.org Foundation C.I.C.
 #
 # Licensed under the Apache License, Version 2.0 (the "License");
 # you may not use this file except in compliance with the License.
@@ -19,55 +20,13 @@ from twisted.internet import defer
 
 from synapse.api.errors import StoreError
 from synapse.logging.opentracing import log_kv, trace
-
-from ._base import SQLBaseStore
+from synapse.storage._base import SQLBaseStore
 
 
 class EndToEndRoomKeyStore(SQLBaseStore):
     @defer.inlineCallbacks
-    def get_e2e_room_key(self, user_id, version, room_id, session_id):
-        """Get the encrypted E2E room key for a given session from a given
-        backup version of room_keys.  We only store the 'best' room key for a given
-        session at a given time, as determined by the handler.
-
-        Args:
-            user_id(str): the user whose backup we're querying
-            version(str): the version ID of the backup for the set of keys we're querying
-            room_id(str): the ID of the room whose keys we're querying.
-                This is a bit redundant as it's implied by the session_id, but
-                we include for consistency with the rest of the API.
-            session_id(str): the session whose room_key we're querying.
-
-        Returns:
-            A deferred dict giving the session_data and message metadata for
-            this room key.
-        """
-
-        row = yield self._simple_select_one(
-            table="e2e_room_keys",
-            keyvalues={
-                "user_id": user_id,
-                "version": version,
-                "room_id": room_id,
-                "session_id": session_id,
-            },
-            retcols=(
-                "first_message_index",
-                "forwarded_count",
-                "is_verified",
-                "session_data",
-            ),
-            desc="get_e2e_room_key",
-        )
-
-        row["session_data"] = json.loads(row["session_data"])
-
-        return row
-
-    @defer.inlineCallbacks
-    def set_e2e_room_key(self, user_id, version, room_id, session_id, room_key):
-        """Replaces or inserts the encrypted E2E room key for a given session in
-        a given backup
+    def update_e2e_room_key(self, user_id, version, room_id, session_id, room_key):
+        """Replaces the encrypted E2E room key for a given session in a given backup
 
         Args:
             user_id(str): the user whose backup we're setting
@@ -79,7 +38,7 @@ class EndToEndRoomKeyStore(SQLBaseStore):
             StoreError
         """
 
-        yield self._simple_upsert(
+        yield self.db.simple_update_one(
             table="e2e_room_keys",
             keyvalues={
                 "user_id": user_id,
@@ -87,21 +46,51 @@ class EndToEndRoomKeyStore(SQLBaseStore):
                 "room_id": room_id,
                 "session_id": session_id,
             },
-            values={
+            updatevalues={
                 "first_message_index": room_key["first_message_index"],
                 "forwarded_count": room_key["forwarded_count"],
                 "is_verified": room_key["is_verified"],
                 "session_data": json.dumps(room_key["session_data"]),
             },
-            lock=False,
+            desc="update_e2e_room_key",
         )
-        log_kv(
-            {
-                "message": "Set room key",
-                "room_id": room_id,
-                "session_id": session_id,
-                "room_key": room_key,
-            }
+
+    @defer.inlineCallbacks
+    def add_e2e_room_keys(self, user_id, version, room_keys):
+        """Bulk add room keys to a given backup.
+
+        Args:
+            user_id (str): the user whose backup we're adding to
+            version (str): the version ID of the backup for the set of keys we're adding to
+            room_keys (iterable[(str, str, dict)]): the keys to add, in the form
+                (roomID, sessionID, keyData)
+        """
+
+        values = []
+        for (room_id, session_id, room_key) in room_keys:
+            values.append(
+                {
+                    "user_id": user_id,
+                    "version": version,
+                    "room_id": room_id,
+                    "session_id": session_id,
+                    "first_message_index": room_key["first_message_index"],
+                    "forwarded_count": room_key["forwarded_count"],
+                    "is_verified": room_key["is_verified"],
+                    "session_data": json.dumps(room_key["session_data"]),
+                }
+            )
+            log_kv(
+                {
+                    "message": "Set room key",
+                    "room_id": room_id,
+                    "session_id": session_id,
+                    "room_key": room_key,
+                }
+            )
+
+        yield self.db.simple_insert_many(
+            table="e2e_room_keys", values=values, desc="add_e2e_room_keys"
         )
 
     @trace
@@ -111,11 +100,11 @@ class EndToEndRoomKeyStore(SQLBaseStore):
         room, or a given session.
 
         Args:
-            user_id(str): the user whose backup we're querying
-            version(str): the version ID of the backup for the set of keys we're querying
-            room_id(str): Optional. the ID of the room whose keys we're querying, if any.
+            user_id (str): the user whose backup we're querying
+            version (str): the version ID of the backup for the set of keys we're querying
+            room_id (str): Optional. the ID of the room whose keys we're querying, if any.
                 If not specified, we return the keys for all the rooms in the backup.
-            session_id(str): Optional. the session whose room_key we're querying, if any.
+            session_id (str): Optional. the session whose room_key we're querying, if any.
                 If specified, we also require the room_id to be specified.
                 If not specified, we return all the keys in this version of
                 the backup (or for the specified room)
@@ -136,7 +125,7 @@ class EndToEndRoomKeyStore(SQLBaseStore):
             if session_id:
                 keyvalues["session_id"] = session_id
 
-        rows = yield self._simple_select_list(
+        rows = yield self.db.simple_select_list(
             table="e2e_room_keys",
             keyvalues=keyvalues,
             retcols=(
@@ -157,12 +146,102 @@ class EndToEndRoomKeyStore(SQLBaseStore):
             room_entry["sessions"][row["session_id"]] = {
                 "first_message_index": row["first_message_index"],
                 "forwarded_count": row["forwarded_count"],
-                "is_verified": row["is_verified"],
+                # is_verified must be returned to the client as a boolean
+                "is_verified": bool(row["is_verified"]),
                 "session_data": json.loads(row["session_data"]),
             }
 
         return sessions
 
+    def get_e2e_room_keys_multi(self, user_id, version, room_keys):
+        """Get multiple room keys at a time.  The difference between this function and
+        get_e2e_room_keys is that this function can be used to retrieve
+        multiple specific keys at a time, whereas get_e2e_room_keys is used for
+        getting all the keys in a backup version, all the keys for a room, or a
+        specific key.
+
+        Args:
+            user_id (str): the user whose backup we're querying
+            version (str): the version ID of the backup we're querying about
+            room_keys (dict[str, dict[str, iterable[str]]]): a map from
+                room ID -> {"sessions": [session ids]} indicating the session IDs
+                that we want to query
+
+        Returns:
+           Deferred[dict[str, dict[str, dict]]]: a map of room IDs to session IDs to room key
+        """
+
+        return self.db.runInteraction(
+            "get_e2e_room_keys_multi",
+            self._get_e2e_room_keys_multi_txn,
+            user_id,
+            version,
+            room_keys,
+        )
+
+    @staticmethod
+    def _get_e2e_room_keys_multi_txn(txn, user_id, version, room_keys):
+        if not room_keys:
+            return {}
+
+        where_clauses = []
+        params = [user_id, version]
+        for room_id, room in room_keys.items():
+            sessions = list(room["sessions"])
+            if not sessions:
+                continue
+            params.append(room_id)
+            params.extend(sessions)
+            where_clauses.append(
+                "(room_id = ? AND session_id IN (%s))"
+                % (",".join(["?" for _ in sessions]),)
+            )
+
+        # check if we're actually querying something
+        if not where_clauses:
+            return {}
+
+        sql = """
+        SELECT room_id, session_id, first_message_index, forwarded_count,
+               is_verified, session_data
+        FROM e2e_room_keys
+        WHERE user_id = ? AND version = ? AND (%s)
+        """ % (
+            " OR ".join(where_clauses)
+        )
+
+        txn.execute(sql, params)
+
+        ret = {}
+
+        for row in txn:
+            room_id = row[0]
+            session_id = row[1]
+            ret.setdefault(room_id, {})
+            ret[room_id][session_id] = {
+                "first_message_index": row[2],
+                "forwarded_count": row[3],
+                "is_verified": row[4],
+                "session_data": json.loads(row[5]),
+            }
+
+        return ret
+
+    def count_e2e_room_keys(self, user_id, version):
+        """Get the number of keys in a backup version.
+
+        Args:
+            user_id (str): the user whose backup we're querying
+            version (str): the version ID of the backup we're querying about
+        """
+
+        return self.db.simple_select_one_onecol(
+            table="e2e_room_keys",
+            keyvalues={"user_id": user_id, "version": version},
+            retcol="COUNT(*)",
+            desc="count_e2e_room_keys",
+        )
+
     @trace
     @defer.inlineCallbacks
     def delete_e2e_room_keys(self, user_id, version, room_id=None, session_id=None):
@@ -189,7 +268,7 @@ class EndToEndRoomKeyStore(SQLBaseStore):
             if session_id:
                 keyvalues["session_id"] = session_id
 
-        yield self._simple_delete(
+        yield self.db.simple_delete(
             table="e2e_room_keys", keyvalues=keyvalues, desc="delete_e2e_room_keys"
         )
 
@@ -220,6 +299,7 @@ class EndToEndRoomKeyStore(SQLBaseStore):
                 version(str)
                 algorithm(str)
                 auth_data(object): opaque dict supplied by the client
+                etag(int): tag of the keys in the backup
         """
 
         def _get_e2e_room_keys_version_info_txn(txn):
@@ -233,17 +313,19 @@ class EndToEndRoomKeyStore(SQLBaseStore):
                     # it isn't there.
                     raise StoreError(404, "No row found")
 
-            result = self._simple_select_one_txn(
+            result = self.db.simple_select_one_txn(
                 txn,
                 table="e2e_room_keys_versions",
                 keyvalues={"user_id": user_id, "version": this_version, "deleted": 0},
-                retcols=("version", "algorithm", "auth_data"),
+                retcols=("version", "algorithm", "auth_data", "etag"),
             )
             result["auth_data"] = json.loads(result["auth_data"])
             result["version"] = str(result["version"])
+            if result["etag"] is None:
+                result["etag"] = 0
             return result
 
-        return self.runInteraction(
+        return self.db.runInteraction(
             "get_e2e_room_keys_version_info", _get_e2e_room_keys_version_info_txn
         )
 
@@ -271,7 +353,7 @@ class EndToEndRoomKeyStore(SQLBaseStore):
 
             new_version = str(int(current_version) + 1)
 
-            self._simple_insert_txn(
+            self.db.simple_insert_txn(
                 txn,
                 table="e2e_room_keys_versions",
                 values={
@@ -284,26 +366,38 @@ class EndToEndRoomKeyStore(SQLBaseStore):
 
             return new_version
 
-        return self.runInteraction(
+        return self.db.runInteraction(
             "create_e2e_room_keys_version_txn", _create_e2e_room_keys_version_txn
         )
 
     @trace
-    def update_e2e_room_keys_version(self, user_id, version, info):
+    def update_e2e_room_keys_version(
+        self, user_id, version, info=None, version_etag=None
+    ):
         """Update a given backup version
 
         Args:
             user_id(str): the user whose backup version we're updating
             version(str): the version ID of the backup version we're updating
-            info(dict): the new backup version info to store
+            info (dict): the new backup version info to store.  If None, then
+                the backup version info is not updated
+            version_etag (Optional[int]): etag of the keys in the backup.  If
+                None, then the etag is not updated
         """
+        updatevalues = {}
 
-        return self._simple_update(
-            table="e2e_room_keys_versions",
-            keyvalues={"user_id": user_id, "version": version},
-            updatevalues={"auth_data": json.dumps(info["auth_data"])},
-            desc="update_e2e_room_keys_version",
-        )
+        if info is not None and "auth_data" in info:
+            updatevalues["auth_data"] = json.dumps(info["auth_data"])
+        if version_etag is not None:
+            updatevalues["etag"] = version_etag
+
+        if updatevalues:
+            return self.db.simple_update(
+                table="e2e_room_keys_versions",
+                keyvalues={"user_id": user_id, "version": version},
+                updatevalues=updatevalues,
+                desc="update_e2e_room_keys_version",
+            )
 
     @trace
     def delete_e2e_room_keys_version(self, user_id, version=None):
@@ -322,16 +416,24 @@ class EndToEndRoomKeyStore(SQLBaseStore):
         def _delete_e2e_room_keys_version_txn(txn):
             if version is None:
                 this_version = self._get_current_version(txn, user_id)
+                if this_version is None:
+                    raise StoreError(404, "No current backup version")
             else:
                 this_version = version
 
-            return self._simple_update_one_txn(
+            self.db.simple_delete_txn(
+                txn,
+                table="e2e_room_keys",
+                keyvalues={"user_id": user_id, "version": this_version},
+            )
+
+            return self.db.simple_update_one_txn(
                 txn,
                 table="e2e_room_keys_versions",
                 keyvalues={"user_id": user_id, "version": this_version},
                 updatevalues={"deleted": 1},
             )
 
-        return self.runInteraction(
+        return self.db.runInteraction(
             "delete_e2e_room_keys_version", _delete_e2e_room_keys_version_txn
         )
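
The new `_get_e2e_room_keys_multi_txn` above selects specific backed-up sessions by OR-ing together one `(room_id = ? AND session_id IN (...))` clause per room, with `user_id = ? AND version = ?` prepended. A standalone sketch of that clause construction (the function name and example IDs are illustrative):

def build_room_keys_where(user_id, version, room_keys):
    """Build the OR'd WHERE fragment plus parameters for the requested sessions."""
    where_clauses = []
    params = [user_id, version]
    for room_id, room in room_keys.items():
        sessions = list(room["sessions"])
        if not sessions:
            continue
        params.append(room_id)
        params.extend(sessions)
        placeholders = ",".join("?" for _ in sessions)
        where_clauses.append("(room_id = ? AND session_id IN (%s))" % placeholders)
    return " OR ".join(where_clauses), params


# build_room_keys_where("@alice:hs", "2", {"!room:hs": {"sessions": ["s1", "s2"]}})
#   -> ("(room_id = ? AND session_id IN (?,?))",
#       ["@alice:hs", "2", "!room:hs", "s1", "s2"])
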
diff --git a/synapse/storage/data_stores/main/end_to_end_keys.py b/synapse/storage/data_stores/main/end_to_end_keys.py
new file mode 100644
index 0000000000..20698bfd16
--- /dev/null
+++ b/synapse/storage/data_stores/main/end_to_end_keys.py
@@ -0,0 +1,721 @@
+# -*- coding: utf-8 -*-
+# Copyright 2015, 2016 OpenMarket Ltd
+# Copyright 2019 New Vector Ltd
+# Copyright 2019 The Matrix.org Foundation C.I.C.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+from typing import Dict, List
+
+from six import iteritems
+
+from canonicaljson import encode_canonical_json, json
+
+from twisted.enterprise.adbapi import Connection
+from twisted.internet import defer
+
+from synapse.logging.opentracing import log_kv, set_tag, trace
+from synapse.storage._base import SQLBaseStore, db_to_json
+from synapse.storage.database import make_in_list_sql_clause
+from synapse.util.caches.descriptors import cached, cachedList
+from synapse.util.iterutils import batch_iter
+
+
+class EndToEndKeyWorkerStore(SQLBaseStore):
+    @trace
+    @defer.inlineCallbacks
+    def get_e2e_device_keys(
+        self, query_list, include_all_devices=False, include_deleted_devices=False
+    ):
+        """Fetch a list of device keys.
+        Args:
+            query_list(list): List of pairs of user_ids and device_ids.
+            include_all_devices (bool): whether to include entries for devices
+                that don't have device keys
+            include_deleted_devices (bool): whether to include null entries for
+                devices which no longer exist (but were in the query_list).
+                This option only takes effect if include_all_devices is true.
+        Returns:
+            Dict mapping from user-id to dict mapping from device_id to
+            key data.  The key data will be a dict in the same format as the
+            DeviceKeys type returned by POST /_matrix/client/r0/keys/query.
+        """
+        set_tag("query_list", query_list)
+        if not query_list:
+            return {}
+
+        results = yield self.db.runInteraction(
+            "get_e2e_device_keys",
+            self._get_e2e_device_keys_txn,
+            query_list,
+            include_all_devices,
+            include_deleted_devices,
+        )
+
+        # Build the result structure, un-jsonify the results, and add the
+        # "unsigned" section
+        rv = {}
+        for user_id, device_keys in iteritems(results):
+            rv[user_id] = {}
+            for device_id, device_info in iteritems(device_keys):
+                r = db_to_json(device_info.pop("key_json"))
+                r["unsigned"] = {}
+                display_name = device_info["device_display_name"]
+                if display_name is not None:
+                    r["unsigned"]["device_display_name"] = display_name
+                if "signatures" in device_info:
+                    for sig_user_id, sigs in device_info["signatures"].items():
+                        r.setdefault("signatures", {}).setdefault(
+                            sig_user_id, {}
+                        ).update(sigs)
+                rv[user_id][device_id] = r
+
+        return rv
+
+    @trace
+    def _get_e2e_device_keys_txn(
+        self, txn, query_list, include_all_devices=False, include_deleted_devices=False
+    ):
+        set_tag("include_all_devices", include_all_devices)
+        set_tag("include_deleted_devices", include_deleted_devices)
+
+        query_clauses = []
+        query_params = []
+        signature_query_clauses = []
+        signature_query_params = []
+
+        if include_all_devices is False:
+            include_deleted_devices = False
+
+        if include_deleted_devices:
+            deleted_devices = set(query_list)
+
+        for (user_id, device_id) in query_list:
+            query_clause = "user_id = ?"
+            query_params.append(user_id)
+            signature_query_clause = "target_user_id = ?"
+            signature_query_params.append(user_id)
+
+            if device_id is not None:
+                query_clause += " AND device_id = ?"
+                query_params.append(device_id)
+                signature_query_clause += " AND target_device_id = ?"
+                signature_query_params.append(device_id)
+
+            signature_query_clause += " AND user_id = ?"
+            signature_query_params.append(user_id)
+
+            query_clauses.append(query_clause)
+            signature_query_clauses.append(signature_query_clause)
+
+        sql = (
+            "SELECT user_id, device_id, "
+            "    d.display_name AS device_display_name, "
+            "    k.key_json"
+            " FROM devices d"
+            "    %s JOIN e2e_device_keys_json k USING (user_id, device_id)"
+            " WHERE %s AND NOT d.hidden"
+        ) % (
+            "LEFT" if include_all_devices else "INNER",
+            " OR ".join("(" + q + ")" for q in query_clauses),
+        )
+
+        txn.execute(sql, query_params)
+        rows = self.db.cursor_to_dict(txn)
+
+        result = {}
+        for row in rows:
+            if include_deleted_devices:
+                deleted_devices.remove((row["user_id"], row["device_id"]))
+            result.setdefault(row["user_id"], {})[row["device_id"]] = row
+
+        if include_deleted_devices:
+            for user_id, device_id in deleted_devices:
+                result.setdefault(user_id, {})[device_id] = None
+
+        # get signatures on the device
+        signature_sql = ("SELECT *  FROM e2e_cross_signing_signatures WHERE %s") % (
+            " OR ".join("(" + q + ")" for q in signature_query_clauses)
+        )
+
+        txn.execute(signature_sql, signature_query_params)
+        rows = self.db.cursor_to_dict(txn)
+
+        # add each cross-signing signature to the correct device in the result dict.
+        for row in rows:
+            signing_user_id = row["user_id"]
+            signing_key_id = row["key_id"]
+            target_user_id = row["target_user_id"]
+            target_device_id = row["target_device_id"]
+            signature = row["signature"]
+
+            target_user_result = result.get(target_user_id)
+            if not target_user_result:
+                continue
+
+            target_device_result = target_user_result.get(target_device_id)
+            if not target_device_result:
+                # note that target_device_result will be None for deleted devices.
+                continue
+
+            target_device_signatures = target_device_result.setdefault("signatures", {})
+            signing_user_signatures = target_device_signatures.setdefault(
+                signing_user_id, {}
+            )
+            signing_user_signatures[signing_key_id] = signature
+
+        log_kv(result)
+        return result
+
+    @defer.inlineCallbacks
+    def get_e2e_one_time_keys(self, user_id, device_id, key_ids):
+        """Retrieve a number of one-time keys for a user
+
+        Args:
+            user_id(str): id of user to get keys for
+            device_id(str): id of device to get keys for
+            key_ids(list[str]): list of key ids (excluding algorithm) to
+                retrieve
+
+        Returns:
+            deferred resolving to Dict[(str, str), str]: map from (algorithm,
+            key_id) to json string for key
+        """
+
+        rows = yield self.db.simple_select_many_batch(
+            table="e2e_one_time_keys_json",
+            column="key_id",
+            iterable=key_ids,
+            retcols=("algorithm", "key_id", "key_json"),
+            keyvalues={"user_id": user_id, "device_id": device_id},
+            desc="add_e2e_one_time_keys_check",
+        )
+        result = {(row["algorithm"], row["key_id"]): row["key_json"] for row in rows}
+        log_kv({"message": "Fetched one time keys for user", "one_time_keys": result})
+        return result
+
+    @defer.inlineCallbacks
+    def add_e2e_one_time_keys(self, user_id, device_id, time_now, new_keys):
+        """Insert some new one time keys for a device. Errors if any of the
+        keys already exist.
+
+        Args:
+            user_id(str): id of user to get keys for
+            device_id(str): id of device to get keys for
+            time_now(long): insertion time to record (ms since epoch)
+            new_keys(iterable[(str, str, str)]): keys to add - each a tuple of
+                (algorithm, key_id, key json)
+        """
+
+        def _add_e2e_one_time_keys(txn):
+            set_tag("user_id", user_id)
+            set_tag("device_id", device_id)
+            set_tag("new_keys", new_keys)
+            # We are protected from race between lookup and insertion due to
+            # a unique constraint. If there is a race of two calls to
+            # `add_e2e_one_time_keys` then they'll conflict and we will only
+            # insert one set.
+            self.db.simple_insert_many_txn(
+                txn,
+                table="e2e_one_time_keys_json",
+                values=[
+                    {
+                        "user_id": user_id,
+                        "device_id": device_id,
+                        "algorithm": algorithm,
+                        "key_id": key_id,
+                        "ts_added_ms": time_now,
+                        "key_json": json_bytes,
+                    }
+                    for algorithm, key_id, json_bytes in new_keys
+                ],
+            )
+            self._invalidate_cache_and_stream(
+                txn, self.count_e2e_one_time_keys, (user_id, device_id)
+            )
+
+        yield self.db.runInteraction(
+            "add_e2e_one_time_keys_insert", _add_e2e_one_time_keys
+        )
+
+    @cached(max_entries=10000)
+    def count_e2e_one_time_keys(self, user_id, device_id):
+        """ Count the number of one time keys the server has for a device
+        Returns:
+            Dict mapping from algorithm to number of keys for that algorithm.
+        """
+
+        def _count_e2e_one_time_keys(txn):
+            sql = (
+                "SELECT algorithm, COUNT(key_id) FROM e2e_one_time_keys_json"
+                " WHERE user_id = ? AND device_id = ?"
+                " GROUP BY algorithm"
+            )
+            txn.execute(sql, (user_id, device_id))
+            result = {}
+            for algorithm, key_count in txn:
+                result[algorithm] = key_count
+            return result
+
+        return self.db.runInteraction(
+            "count_e2e_one_time_keys", _count_e2e_one_time_keys
+        )
+
+    @defer.inlineCallbacks
+    def get_e2e_cross_signing_key(self, user_id, key_type, from_user_id=None):
+        """Returns a user's cross-signing key.
+
+        Args:
+            user_id (str): the user whose key is being requested
+            key_type (str): the type of key that is being requested: either 'master'
+                for a master key, 'self_signing' for a self-signing key, or
+                'user_signing' for a user-signing key
+            from_user_id (str): if specified, signatures made by this user on
+                the self-signing key will be included in the result
+
+        Returns:
+            dict of the key data or None if not found
+        """
+        res = yield self.get_e2e_cross_signing_keys_bulk([user_id], from_user_id)
+        user_keys = res.get(user_id)
+        if not user_keys:
+            return None
+        return user_keys.get(key_type)
+
+    @cached(num_args=1)
+    def _get_bare_e2e_cross_signing_keys(self, user_id):
+        """Dummy function.  Only used to make a cache for
+        _get_bare_e2e_cross_signing_keys_bulk.
+        """
+        raise NotImplementedError()
+
+    @cachedList(
+        cached_method_name="_get_bare_e2e_cross_signing_keys",
+        list_name="user_ids",
+        num_args=1,
+    )
+    def _get_bare_e2e_cross_signing_keys_bulk(
+        self, user_ids: List[str]
+    ) -> Dict[str, Dict[str, dict]]:
+        """Returns the cross-signing keys for a set of users.  The output of this
+        function should be passed to _get_e2e_cross_signing_signatures_txn if
+        the signatures for the calling user need to be fetched.
+
+        Args:
+            user_ids (list[str]): the users whose keys are being requested
+
+        Returns:
+            dict[str, dict[str, dict]]: mapping from user ID to key type to key
+                data.  If a user's cross-signing keys were not found, either
+                their user ID will not be in the dict, or their user ID will map
+                to None.
+
+        """
+        return self.db.runInteraction(
+            "get_bare_e2e_cross_signing_keys_bulk",
+            self._get_bare_e2e_cross_signing_keys_bulk_txn,
+            user_ids,
+        )
+
+    def _get_bare_e2e_cross_signing_keys_bulk_txn(
+        self, txn: Connection, user_ids: List[str],
+    ) -> Dict[str, Dict[str, dict]]:
+        """Returns the cross-signing keys for a set of users.  The output of this
+        function should be passed to _get_e2e_cross_signing_signatures_txn if
+        the signatures for the calling user need to be fetched.
+
+        Args:
+            txn (twisted.enterprise.adbapi.Connection): db connection
+            user_ids (list[str]): the users whose keys are being requested
+
+        Returns:
+            dict[str, dict[str, dict]]: mapping from user ID to key type to key
+                data.  If a user's cross-signing keys were not found, their user
+                ID will not be in the dict.
+
+        """
+        result = {}
+
+        for user_chunk in batch_iter(user_ids, 100):
+            clause, params = make_in_list_sql_clause(
+                txn.database_engine, "k.user_id", user_chunk
+            )
+            sql = (
+                """
+                SELECT k.user_id, k.keytype, k.keydata, k.stream_id
+                  FROM e2e_cross_signing_keys k
+                  INNER JOIN (SELECT user_id, keytype, MAX(stream_id) AS stream_id
+                                FROM e2e_cross_signing_keys
+                               GROUP BY user_id, keytype) s
+                 USING (user_id, stream_id, keytype)
+                 WHERE
+            """
+                + clause
+            )
+
+            txn.execute(sql, params)
+            rows = self.db.cursor_to_dict(txn)
+
+            for row in rows:
+                user_id = row["user_id"]
+                key_type = row["keytype"]
+                key = json.loads(row["keydata"])
+                user_info = result.setdefault(user_id, {})
+                user_info[key_type] = key
+
+        return result
+
+    def _get_e2e_cross_signing_signatures_txn(
+        self, txn: Connection, keys: Dict[str, Dict[str, dict]], from_user_id: str,
+    ) -> Dict[str, Dict[str, dict]]:
+        """Returns the cross-signing signatures made by a user on a set of keys.
+
+        Args:
+            txn (twisted.enterprise.adbapi.Connection): db connection
+            keys (dict[str, dict[str, dict]]): a map of user ID to key type to
+                key data.  This dict will be modified to add signatures.
+            from_user_id (str): fetch the signatures made by this user
+
+        Returns:
+            dict[str, dict[str, dict]]: mapping from user ID to key type to key
+                data.  The return value will be the same as the keys argument,
+                with the modifications included.
+        """
+
+        # find out what cross-signing keys (a.k.a. devices) we need to get
+        # signatures for.  This is a map of (user_id, device_id) to key type
+        # (device_id is the key's public part).
+        devices = {}
+
+        for user_id, user_info in keys.items():
+            if user_info is None:
+                continue
+            for key_type, key in user_info.items():
+                device_id = None
+                for k in key["keys"].values():
+                    device_id = k
+                devices[(user_id, device_id)] = key_type
+
+        for batch in batch_iter(devices.keys(), size=100):
+            sql = """
+                SELECT target_user_id, target_device_id, key_id, signature
+                  FROM e2e_cross_signing_signatures
+                 WHERE user_id = ?
+                   AND (%s)
+            """ % (
+                " OR ".join(
+                    "(target_user_id = ? AND target_device_id = ?)" for _ in batch
+                )
+            )
+            query_params = [from_user_id]
+            for item in batch:
+                # item is a (user_id, device_id) tuple
+                query_params.extend(item)
+
+            txn.execute(sql, query_params)
+            rows = self.db.cursor_to_dict(txn)
+
+            # and add the signatures to the appropriate keys
+            for row in rows:
+                key_id = row["key_id"]
+                target_user_id = row["target_user_id"]
+                target_device_id = row["target_device_id"]
+                key_type = devices[(target_user_id, target_device_id)]
+                # We need to copy everything, because the result may have come
+                # from the cache.  dict.copy only does a shallow copy, so we
+                # need to recursively copy the dicts that will be modified.
+                user_info = keys[target_user_id] = keys[target_user_id].copy()
+                target_user_key = user_info[key_type] = user_info[key_type].copy()
+                if "signatures" in target_user_key:
+                    signatures = target_user_key["signatures"] = target_user_key[
+                        "signatures"
+                    ].copy()
+                    if from_user_id in signatures:
+                        user_sigs = signatures[from_user_id] = signatures[from_user_id].copy()
+                        user_sigs[key_id] = row["signature"]
+                    else:
+                        signatures[from_user_id] = {key_id: row["signature"]}
+                else:
+                    target_user_key["signatures"] = {
+                        from_user_id: {key_id: row["signature"]}
+                    }
+
+        return keys
+
+    @defer.inlineCallbacks
+    def get_e2e_cross_signing_keys_bulk(
+        self, user_ids: List[str], from_user_id: str = None
+    ) -> defer.Deferred:
+        """Returns the cross-signing keys for a set of users.
+
+        Args:
+            user_ids (list[str]): the users whose keys are being requested
+            from_user_id (str): if specified, signatures made by this user on
+                the self-signing keys will be included in the result
+
+        Returns:
+            Deferred[dict[str, dict[str, dict]]]: map of user ID to key type to
+                key data.  If a user's cross-signing keys were not found, either
+                their user ID will not be in the dict, or their user ID will map
+                to None.
+        """
+
+        result = yield self._get_bare_e2e_cross_signing_keys_bulk(user_ids)
+
+        if from_user_id:
+            result = yield self.db.runInteraction(
+                "get_e2e_cross_signing_signatures",
+                self._get_e2e_cross_signing_signatures_txn,
+                result,
+                from_user_id,
+            )
+
+        return result
+
+    def get_all_user_signature_changes_for_remotes(self, from_key, to_key, limit):
+        """Return a list of changes from the user signature stream to notify remotes.
+        Note that the user signature stream represents when a user signs their
+        device with their user-signing key, which is not published to other
+        users or servers, so no `destination` is needed in the returned
+        list. However, this is needed to poke workers.
+
+        Args:
+            from_key (int): the stream ID to start at (exclusive)
+            to_key (int): the stream ID to end at (inclusive)
+
+        Returns:
+            Deferred[list[(int,str)]]: a list of `(stream_id, user_id)`
+        """
+        sql = """
+            SELECT stream_id, from_user_id AS user_id
+            FROM user_signature_stream
+            WHERE ? < stream_id AND stream_id <= ?
+            ORDER BY stream_id ASC
+            LIMIT ?
+        """
+        return self.db.execute(
+            "get_all_user_signature_changes_for_remotes",
+            None,
+            sql,
+            from_key,
+            to_key,
+            limit,
+        )
+
+
+class EndToEndKeyStore(EndToEndKeyWorkerStore, SQLBaseStore):
+    def set_e2e_device_keys(self, user_id, device_id, time_now, device_keys):
+        """Stores device keys for a device. Returns whether there was a change
+        or the keys were already in the database.
+        """
+
+        def _set_e2e_device_keys_txn(txn):
+            set_tag("user_id", user_id)
+            set_tag("device_id", device_id)
+            set_tag("time_now", time_now)
+            set_tag("device_keys", device_keys)
+
+            old_key_json = self.db.simple_select_one_onecol_txn(
+                txn,
+                table="e2e_device_keys_json",
+                keyvalues={"user_id": user_id, "device_id": device_id},
+                retcol="key_json",
+                allow_none=True,
+            )
+
+            # In py3 we need old_key_json to match new_key_json type. The DB
+            # returns unicode while encode_canonical_json returns bytes.
+            new_key_json = encode_canonical_json(device_keys).decode("utf-8")
+
+            if old_key_json == new_key_json:
+                log_kv({"Message": "Device key already stored."})
+                return False
+
+            self.db.simple_upsert_txn(
+                txn,
+                table="e2e_device_keys_json",
+                keyvalues={"user_id": user_id, "device_id": device_id},
+                values={"ts_added_ms": time_now, "key_json": new_key_json},
+            )
+            log_kv({"message": "Device keys stored."})
+            return True
+
+        return self.db.runInteraction("set_e2e_device_keys", _set_e2e_device_keys_txn)
+
+    def claim_e2e_one_time_keys(self, query_list):
+        """Take a list of one time keys out of the database"""
+
+        @trace
+        def _claim_e2e_one_time_keys(txn):
+            sql = (
+                "SELECT key_id, key_json FROM e2e_one_time_keys_json"
+                " WHERE user_id = ? AND device_id = ? AND algorithm = ?"
+                " LIMIT 1"
+            )
+            result = {}
+            delete = []
+            for user_id, device_id, algorithm in query_list:
+                user_result = result.setdefault(user_id, {})
+                device_result = user_result.setdefault(device_id, {})
+                txn.execute(sql, (user_id, device_id, algorithm))
+                for key_id, key_json in txn:
+                    device_result[algorithm + ":" + key_id] = key_json
+                    delete.append((user_id, device_id, algorithm, key_id))
+            sql = (
+                "DELETE FROM e2e_one_time_keys_json"
+                " WHERE user_id = ? AND device_id = ? AND algorithm = ?"
+                " AND key_id = ?"
+            )
+            for user_id, device_id, algorithm, key_id in delete:
+                log_kv(
+                    {
+                        "message": "Executing claim e2e_one_time_keys transaction on database."
+                    }
+                )
+                txn.execute(sql, (user_id, device_id, algorithm, key_id))
+                log_kv({"message": "finished executing and invalidating cache"})
+                self._invalidate_cache_and_stream(
+                    txn, self.count_e2e_one_time_keys, (user_id, device_id)
+                )
+            return result
+
+        return self.db.runInteraction(
+            "claim_e2e_one_time_keys", _claim_e2e_one_time_keys
+        )
+
+    def delete_e2e_keys_by_device(self, user_id, device_id):
+        def delete_e2e_keys_by_device_txn(txn):
+            log_kv(
+                {
+                    "message": "Deleting keys for device",
+                    "device_id": device_id,
+                    "user_id": user_id,
+                }
+            )
+            self.db.simple_delete_txn(
+                txn,
+                table="e2e_device_keys_json",
+                keyvalues={"user_id": user_id, "device_id": device_id},
+            )
+            self.db.simple_delete_txn(
+                txn,
+                table="e2e_one_time_keys_json",
+                keyvalues={"user_id": user_id, "device_id": device_id},
+            )
+            self._invalidate_cache_and_stream(
+                txn, self.count_e2e_one_time_keys, (user_id, device_id)
+            )
+
+        return self.db.runInteraction(
+            "delete_e2e_keys_by_device", delete_e2e_keys_by_device_txn
+        )
+
+    def _set_e2e_cross_signing_key_txn(self, txn, user_id, key_type, key):
+        """Set a user's cross-signing key.
+
+        Args:
+            txn (twisted.enterprise.adbapi.Connection): db connection
+            user_id (str): the user to set the signing key for
+            key_type (str): the type of key that is being set: either 'master'
+                for a master key, 'self_signing' for a self-signing key, or
+                'user_signing' for a user-signing key
+            key (dict): the key data
+        """
+        # the 'key' dict will look something like:
+        # {
+        #   "user_id": "@alice:example.com",
+        #   "usage": ["self_signing"],
+        #   "keys": {
+        #     "ed25519:base64+self+signing+public+key": "base64+self+signing+public+key",
+        #   },
+        #   "signatures": {
+        #     "@alice:example.com": {
+        #       "ed25519:base64+master+public+key": "base64+signature"
+        #     }
+        #   }
+        # }
+        # The "keys" property must only have one entry, which will be the public
+        # key, so we just grab the first value in there
+        pubkey = next(iter(key["keys"].values()))
+
+        # The cross-signing keys need to occupy the same namespace as devices,
+        # since signatures are identified by device ID.  So add an entry to the
+        # device table to make sure that we don't have a collision with device
+        # IDs.
+        # We only need to do this for local users, since remote servers should be
+        # responsible for checking this for their own users.
+        if self.hs.is_mine_id(user_id):
+            self.db.simple_insert_txn(
+                txn,
+                "devices",
+                values={
+                    "user_id": user_id,
+                    "device_id": pubkey,
+                    "display_name": key_type + " signing key",
+                    "hidden": True,
+                },
+            )
+
+        # and finally, store the key itself
+        with self._cross_signing_id_gen.get_next() as stream_id:
+            self.db.simple_insert_txn(
+                txn,
+                "e2e_cross_signing_keys",
+                values={
+                    "user_id": user_id,
+                    "keytype": key_type,
+                    "keydata": json.dumps(key),
+                    "stream_id": stream_id,
+                },
+            )
+
+        self._invalidate_cache_and_stream(
+            txn, self._get_bare_e2e_cross_signing_keys, (user_id,)
+        )
+
+    def set_e2e_cross_signing_key(self, user_id, key_type, key):
+        """Set a user's cross-signing key.
+
+        Args:
+            user_id (str): the user to set the user-signing key for
+            key_type (str): the type of cross-signing key to set
+            key (dict): the key data
+        """
+        return self.db.runInteraction(
+            "add_e2e_cross_signing_key",
+            self._set_e2e_cross_signing_key_txn,
+            user_id,
+            key_type,
+            key,
+        )
+
+    def store_e2e_cross_signing_signatures(self, user_id, signatures):
+        """Stores cross-signing signatures.
+
+        Args:
+            user_id (str): the user who made the signatures
+            signatures (iterable[SignatureListItem]): signatures to add
+        """
+        return self.db.simple_insert_many(
+            "e2e_cross_signing_signatures",
+            [
+                {
+                    "user_id": user_id,
+                    "key_id": item.signing_key_id,
+                    "target_user_id": item.target_user_id,
+                    "target_device_id": item.target_device_id,
+                    "signature": item.signature,
+                }
+                for item in signatures
+            ],
+            "add_e2e_signing_key",
+        )
diff --git a/synapse/storage/event_federation.py b/synapse/storage/data_stores/main/event_federation.py
index 4f500d893e..24ce8c4330 100644
--- a/synapse/storage/event_federation.py
+++ b/synapse/storage/data_stores/main/event_federation.py
@@ -12,22 +12,22 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
+import itertools
 import logging
-import random
+from typing import Dict, List, Optional, Set, Tuple
 
-from six.moves import range
 from six.moves.queue import Empty, PriorityQueue
 
-from unpaddedbase64 import encode_base64
-
 from twisted.internet import defer
 
 from synapse.api.errors import StoreError
 from synapse.metrics.background_process_metrics import run_as_background_process
-from synapse.storage._base import SQLBaseStore
-from synapse.storage.events_worker import EventsWorkerStore
-from synapse.storage.signatures import SignatureWorkerStore
+from synapse.storage._base import SQLBaseStore, make_in_list_sql_clause
+from synapse.storage.data_stores.main.events_worker import EventsWorkerStore
+from synapse.storage.data_stores.main.signatures import SignatureWorkerStore
+from synapse.storage.database import Database
 from synapse.util.caches.descriptors import cached
+from synapse.util.iterutils import batch_iter
 
 logger = logging.getLogger(__name__)
 
@@ -47,37 +47,55 @@ class EventFederationWorkerStore(EventsWorkerStore, SignatureWorkerStore, SQLBas
             event_ids, include_given=include_given
         ).addCallback(self.get_events_as_list)
 
-    def get_auth_chain_ids(self, event_ids, include_given=False):
+    def get_auth_chain_ids(
+        self,
+        event_ids: List[str],
+        include_given: bool = False,
+        ignore_events: Optional[Set[str]] = None,
+    ):
         """Get auth events for given event_ids. The events *must* be state events.
 
         Args:
-            event_ids (list): state events
-            include_given (bool): include the given events in result
+            event_ids: state events
+            include_given: include the given events in result
+            ignore_events: Set of events to exclude from the returned auth
+                chain. This is useful if the caller will just discard the
+                given events anyway, and saves us from figuring out their auth
+                chains if not required.
 
         Returns:
             list of event_ids
         """
-        return self.runInteraction(
-            "get_auth_chain_ids", self._get_auth_chain_ids_txn, event_ids, include_given
+        return self.db.runInteraction(
+            "get_auth_chain_ids",
+            self._get_auth_chain_ids_txn,
+            event_ids,
+            include_given,
+            ignore_events,
         )
 
-    def _get_auth_chain_ids_txn(self, txn, event_ids, include_given):
+    def _get_auth_chain_ids_txn(self, txn, event_ids, include_given, ignore_events):
+        if ignore_events is None:
+            ignore_events = set()
+
         if include_given:
             results = set(event_ids)
         else:
             results = set()
 
-        base_sql = "SELECT auth_id FROM event_auth WHERE event_id IN (%s)"
+        base_sql = "SELECT auth_id FROM event_auth WHERE "
 
         front = set(event_ids)
         while front:
             new_front = set()
-            front_list = list(front)
-            chunks = [front_list[x : x + 100] for x in range(0, len(front), 100)]
-            for chunk in chunks:
-                txn.execute(base_sql % (",".join(["?"] * len(chunk)),), chunk)
-                new_front.update([r[0] for r in txn])
+            for chunk in batch_iter(front, 100):
+                clause, args = make_in_list_sql_clause(
+                    txn.database_engine, "event_id", chunk
+                )
+                txn.execute(base_sql + clause, args)
+                new_front.update(r[0] for r in txn)
 
+            new_front -= ignore_events
             new_front -= results
 
             front = new_front
@@ -85,13 +103,170 @@ class EventFederationWorkerStore(EventsWorkerStore, SignatureWorkerStore, SQLBas
 
         return list(results)
 
+    def get_auth_chain_difference(self, state_sets: List[Set[str]]):
+        """Given sets of state events figure out the auth chain difference (as
+        per state res v2 algorithm).
+
+        This is equivalent to fetching the full auth chain for each set of state
+        and returning the events that don't appear in each and every auth
+        chain.
+
+        Returns:
+            Deferred[Set[str]]
+        """
+
+        return self.db.runInteraction(
+            "get_auth_chain_difference",
+            self._get_auth_chain_difference_txn,
+            state_sets,
+        )
+
+    def _get_auth_chain_difference_txn(
+        self, txn, state_sets: List[Set[str]]
+    ) -> Set[str]:
+
+        # Algorithm Description
+        # ~~~~~~~~~~~~~~~~~~~~~
+        #
+        # The idea here is to basically walk the auth graph of each state set in
+        # tandem, keeping track of which auth events are reachable by each state
+        # set. If we reach an auth event we've already visited (via a different
+        # state set) then we mark that auth event and all ancestors as reachable
+        # by the state set. This requires that we keep track of the auth chains
+        # in memory.
+        #
+        # Doing it in such a way means that we can stop early if all auth
+        # events we're currently walking are reachable by all state sets.
+        #
+        # *Note*: We can't stop walking an event's auth chain if it is reachable
+        # by all state sets. This is because other auth chains we're walking
+        # might be reachable only via the original auth chain. For example,
+        # given the following auth chain:
+        #
+        #       A -> C -> D -> E
+        #           /         /
+        #       B -´---------´
+        #
+        # and state sets {A} and {B} then walking the auth chains of A and B
+        # would immediately show that C is reachable by both. However, if we
+        # stopped at C then we'd only reach E via the auth chain of B and so E
+        # would erroneously get included in the returned difference.
+        #
+        # The other thing that we do is limit the number of auth chains we walk
+        # at once, due to practical limits (i.e. we can only query the database
+        # with a limited set of parameters). We pick the auth chains we walk
+        # each iteration based on their depth, in the hope that events with a
+        # lower depth are likely reachable by those with higher depths.
+        #
+        # We could use any ordering that we believe would give a rough
+        # topological ordering, e.g. origin server timestamp. If the ordering
+        # chosen is not topological then the algorithm still produces the right
+        # result, but perhaps a bit more inefficiently. This is why it is safe
+        # to use "depth" here.
+
+        initial_events = set(state_sets[0]).union(*state_sets[1:])
+
+        # Dict from events in auth chains to which sets *cannot* reach them.
+        # I.e. if the set is empty then all sets can reach the event.
+        event_to_missing_sets = {
+            event_id: {i for i, a in enumerate(state_sets) if event_id not in a}
+            for event_id in initial_events
+        }
+
+        # The sorted list of events whose auth chains we should walk.
+        search = []  # type: List[Tuple[int, str]]
+
+        # We need to get the depth of the initial events for sorting purposes.
+        sql = """
+            SELECT depth, event_id FROM events
+            WHERE %s
+        """
+        # the list can be huge, so let's avoid looking them all up in one massive
+        # query.
+        for batch in batch_iter(initial_events, 1000):
+            clause, args = make_in_list_sql_clause(
+                txn.database_engine, "event_id", batch
+            )
+            txn.execute(sql % (clause,), args)
+
+            # I think building a temporary list with fetchall is more efficient than
+            # just `search.extend(txn)`, but this is unconfirmed
+            search.extend(txn.fetchall())
+
+        # sort by depth
+        search.sort()
+
+        # Map from event to its auth events
+        event_to_auth_events = {}  # type: Dict[str, Set[str]]
+
+        base_sql = """
+            SELECT a.event_id, auth_id, depth
+            FROM event_auth AS a
+            INNER JOIN events AS e ON (e.event_id = a.auth_id)
+            WHERE
+        """
+
+        while search:
+            # Check whether all our current walks are reachable by all state
+            # sets. If so we can bail.
+            if all(not event_to_missing_sets[eid] for _, eid in search):
+                break
+
+            # Fetch the auth events and their depths of the N last events we're
+            # currently walking
+            search, chunk = search[:-100], search[-100:]
+            clause, args = make_in_list_sql_clause(
+                txn.database_engine, "a.event_id", [e_id for _, e_id in chunk]
+            )
+            txn.execute(base_sql + clause, args)
+
+            for event_id, auth_event_id, auth_event_depth in txn:
+                event_to_auth_events.setdefault(event_id, set()).add(auth_event_id)
+
+                sets = event_to_missing_sets.get(auth_event_id)
+                if sets is None:
+                    # First time we're seeing this event, so we add it to the
+                    # queue of things to fetch.
+                    search.append((auth_event_depth, auth_event_id))
+
+                    # Assume that this event is unreachable from any of the
+                    # state sets until proven otherwise
+                    sets = event_to_missing_sets[auth_event_id] = set(
+                        range(len(state_sets))
+                    )
+                else:
+                    # We've previously seen this event, so look up its auth
+                    # events and recursively mark all ancestors as reachable
+                    # by the current event's state set.
+                    a_ids = event_to_auth_events.get(auth_event_id)
+                    while a_ids:
+                        new_aids = set()
+                        for a_id in a_ids:
+                            event_to_missing_sets[a_id].intersection_update(
+                                event_to_missing_sets[event_id]
+                            )
+
+                            b = event_to_auth_events.get(a_id)
+                            if b:
+                                new_aids.update(b)
+
+                        a_ids = new_aids
+
+                # Mark that the auth event is reachable by the appropriate sets.
+                sets.intersection_update(event_to_missing_sets[event_id])
+
+            search.sort()
+
+        # Return all events where not all sets can reach them.
+        return {eid for eid, n in event_to_missing_sets.items() if n}
+
     def get_oldest_events_in_room(self, room_id):
-        return self.runInteraction(
+        return self.db.runInteraction(
             "get_oldest_events_in_room", self._get_oldest_events_in_room_txn, room_id
         )
 
     def get_oldest_events_with_depth_in_room(self, room_id):
-        return self.runInteraction(
+        return self.db.runInteraction(
             "get_oldest_events_with_depth_in_room",
             self.get_oldest_events_with_depth_in_room_txn,
             room_id,
@@ -122,7 +297,7 @@ class EventFederationWorkerStore(EventsWorkerStore, SignatureWorkerStore, SQLBas
         Returns
             Deferred[int]
         """
-        rows = yield self._simple_select_many_batch(
+        rows = yield self.db.simple_select_many_batch(
             table="events",
             column="event_id",
             iterable=event_ids,
@@ -136,15 +311,14 @@ class EventFederationWorkerStore(EventsWorkerStore, SignatureWorkerStore, SQLBas
             return max(row["depth"] for row in rows)
 
     def _get_oldest_events_in_room_txn(self, txn, room_id):
-        return self._simple_select_onecol_txn(
+        return self.db.simple_select_onecol_txn(
             txn,
             table="event_backward_extremities",
             keyvalues={"room_id": room_id},
             retcol="event_id",
         )
 
-    @defer.inlineCallbacks
-    def get_prev_events_for_room(self, room_id):
+    def get_prev_events_for_room(self, room_id: str):
         """
         Gets a subset of the current forward extremities in the given room.
 
@@ -155,47 +329,37 @@ class EventFederationWorkerStore(EventsWorkerStore, SignatureWorkerStore, SQLBas
             room_id (str): room_id
 
         Returns:
-            Deferred[list[(str, dict[str, str], int)]]
-                for each event, a tuple of (event_id, hashes, depth)
-                where *hashes* is a map from algorithm to hash.
+            Deferred[List[str]]: the event ids of the forward extremities
+
         """
-        res = yield self.get_latest_event_ids_and_hashes_in_room(room_id)
-        if len(res) > 10:
-            # Sort by reverse depth, so we point to the most recent.
-            res.sort(key=lambda a: -a[2])
 
-            # we use half of the limit for the actual most recent events, and
-            # the other half to randomly point to some of the older events, to
-            # make sure that we don't completely ignore the older events.
-            res = res[0:5] + random.sample(res[5:], 5)
+        return self.db.runInteraction(
+            "get_prev_events_for_room", self._get_prev_events_for_room_txn, room_id
+        )
 
-        return res
+    def _get_prev_events_for_room_txn(self, txn, room_id: str):
+        # we just use the 10 newest events. Older events will become
+        # prev_events of future events.
 
-    def get_latest_event_ids_and_hashes_in_room(self, room_id):
+        sql = """
+            SELECT e.event_id FROM event_forward_extremities AS f
+            INNER JOIN events AS e USING (event_id)
+            WHERE f.room_id = ?
+            ORDER BY e.depth DESC
+            LIMIT 10
         """
-        Gets the current forward extremities in the given room
 
-        Args:
-            room_id (str): room_id
+        txn.execute(sql, (room_id,))
 
-        Returns:
-            Deferred[list[(str, dict[str, str], int)]]
-                for each event, a tuple of (event_id, hashes, depth)
-                where *hashes* is a map from algorithm to hash.
-        """
+        return [row[0] for row in txn]
 
-        return self.runInteraction(
-            "get_latest_event_ids_and_hashes_in_room",
-            self._get_latest_event_ids_and_hashes_in_room,
-            room_id,
-        )
-
-    def get_rooms_with_many_extremities(self, min_count, limit):
+    def get_rooms_with_many_extremities(self, min_count, limit, room_id_filter):
         """Get the top rooms with at least N extremities.
 
         Args:
             min_count (int): The minimum number of extremities
             limit (int): The maximum number of rooms to return.
+            room_id_filter (iterable[str]): room_ids to exclude from the results
 
         Returns:
             Deferred[list]: At most `limit` room IDs that have at least
@@ -203,60 +367,49 @@ class EventFederationWorkerStore(EventsWorkerStore, SignatureWorkerStore, SQLBas
         """
 
         def _get_rooms_with_many_extremities_txn(txn):
+            where_clause = "1=1"
+            if room_id_filter:
+                where_clause = "room_id NOT IN (%s)" % (
+                    ",".join("?" for _ in room_id_filter),
+                )
+
             sql = """
                 SELECT room_id FROM event_forward_extremities
+                WHERE %s
                 GROUP BY room_id
                 HAVING count(*) > ?
                 ORDER BY count(*) DESC
                 LIMIT ?
-            """
+            """ % (
+                where_clause,
+            )
 
-            txn.execute(sql, (min_count, limit))
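+            # The NOT IN placeholders come first in the SQL, so the filter ids
+            # must precede min_count and limit in the argument list.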
+            query_args = list(itertools.chain(room_id_filter, [min_count, limit]))
+            txn.execute(sql, query_args)
             return [room_id for room_id, in txn]
 
-        return self.runInteraction(
+        return self.db.runInteraction(
             "get_rooms_with_many_extremities", _get_rooms_with_many_extremities_txn
         )
 
     @cached(max_entries=5000, iterable=True)
     def get_latest_event_ids_in_room(self, room_id):
-        return self._simple_select_onecol(
+        return self.db.simple_select_onecol(
             table="event_forward_extremities",
             keyvalues={"room_id": room_id},
             retcol="event_id",
             desc="get_latest_event_ids_in_room",
         )
 
-    def _get_latest_event_ids_and_hashes_in_room(self, txn, room_id):
-        sql = (
-            "SELECT e.event_id, e.depth FROM events as e "
-            "INNER JOIN event_forward_extremities as f "
-            "ON e.event_id = f.event_id "
-            "AND e.room_id = f.room_id "
-            "WHERE f.room_id = ?"
-        )
-
-        txn.execute(sql, (room_id,))
-
-        results = []
-        for event_id, depth in txn.fetchall():
-            hashes = self._get_event_reference_hashes_txn(txn, event_id)
-            prev_hashes = {
-                k: encode_base64(v) for k, v in hashes.items() if k == "sha256"
-            }
-            results.append((event_id, prev_hashes, depth))
-
-        return results
-
     def get_min_depth(self, room_id):
         """ For hte given room, get the minimum depth we have seen for it.
         """
-        return self.runInteraction(
+        return self.db.runInteraction(
             "get_min_depth", self._get_min_depth_interaction, room_id
         )
 
     def _get_min_depth_interaction(self, txn, room_id):
-        min_depth = self._simple_select_one_onecol_txn(
+        min_depth = self.db.simple_select_one_onecol_txn(
             txn,
             table="room_depth",
             keyvalues={"room_id": room_id},
@@ -322,7 +475,7 @@ class EventFederationWorkerStore(EventsWorkerStore, SignatureWorkerStore, SQLBas
             txn.execute(sql, (stream_ordering, room_id))
             return [event_id for event_id, in txn]
 
-        return self.runInteraction(
+        return self.db.runInteraction(
             "get_forward_extremeties_for_room", get_forward_extremeties_for_room_txn
         )
 
@@ -337,7 +490,7 @@ class EventFederationWorkerStore(EventsWorkerStore, SignatureWorkerStore, SQLBas
             limit (int)
         """
         return (
-            self.runInteraction(
+            self.db.runInteraction(
                 "get_backfill_events",
                 self._get_backfill_events,
                 room_id,
@@ -349,9 +502,7 @@ class EventFederationWorkerStore(EventsWorkerStore, SignatureWorkerStore, SQLBas
         )
 
     def _get_backfill_events(self, txn, room_id, event_list, limit):
-        logger.debug(
-            "_get_backfill_events: %s, %s, %s", room_id, repr(event_list), limit
-        )
+        logger.debug("_get_backfill_events: %s, %r, %s", room_id, event_list, limit)
 
         event_results = set()
 
@@ -370,7 +521,7 @@ class EventFederationWorkerStore(EventsWorkerStore, SignatureWorkerStore, SQLBas
         queue = PriorityQueue()
 
         for event_id in event_list:
-            depth = self._simple_select_one_onecol_txn(
+            depth = self.db.simple_select_one_onecol_txn(
                 txn,
                 table="events",
                 keyvalues={"event_id": event_id, "room_id": room_id},
@@ -402,7 +553,7 @@ class EventFederationWorkerStore(EventsWorkerStore, SignatureWorkerStore, SQLBas
 
     @defer.inlineCallbacks
     def get_missing_events(self, room_id, earliest_events, latest_events, limit):
-        ids = yield self.runInteraction(
+        ids = yield self.db.runInteraction(
             "get_missing_events",
             self._get_missing_events,
             room_id,
@@ -432,7 +583,7 @@ class EventFederationWorkerStore(EventsWorkerStore, SignatureWorkerStore, SQLBas
                     query, (room_id, event_id, False, limit - len(event_results))
                 )
 
-                new_results = set(t[0] for t in txn) - seen_events
+                new_results = {t[0] for t in txn} - seen_events
 
                 new_front |= new_results
                 seen_events |= new_results
@@ -455,7 +606,7 @@ class EventFederationWorkerStore(EventsWorkerStore, SignatureWorkerStore, SQLBas
         Returns:
             Deferred[list[str]]
         """
-        rows = yield self._simple_select_many_batch(
+        rows = yield self.db.simple_select_many_batch(
             table="event_edges",
             column="prev_event_id",
             iterable=event_ids,
@@ -478,10 +629,10 @@ class EventFederationStore(EventFederationWorkerStore):
 
     EVENT_AUTH_STATE_ONLY = "event_auth_state_only"
 
-    def __init__(self, db_conn, hs):
-        super(EventFederationStore, self).__init__(db_conn, hs)
+    def __init__(self, database: Database, db_conn, hs):
+        super(EventFederationStore, self).__init__(database, db_conn, hs)
 
-        self.register_background_update_handler(
+        self.db.updates.register_background_update_handler(
             self.EVENT_AUTH_STATE_ONLY, self._background_delete_non_state_event_auth
         )
 
@@ -489,89 +640,6 @@ class EventFederationStore(EventFederationWorkerStore):
             self._delete_old_forward_extrem_cache, 60 * 60 * 1000
         )
 
-    def _update_min_depth_for_room_txn(self, txn, room_id, depth):
-        min_depth = self._get_min_depth_interaction(txn, room_id)
-
-        if min_depth and depth >= min_depth:
-            return
-
-        self._simple_upsert_txn(
-            txn,
-            table="room_depth",
-            keyvalues={"room_id": room_id},
-            values={"min_depth": depth},
-        )
-
-    def _handle_mult_prev_events(self, txn, events):
-        """
-        For the given event, update the event edges table and forward and
-        backward extremities tables.
-        """
-        self._simple_insert_many_txn(
-            txn,
-            table="event_edges",
-            values=[
-                {
-                    "event_id": ev.event_id,
-                    "prev_event_id": e_id,
-                    "room_id": ev.room_id,
-                    "is_state": False,
-                }
-                for ev in events
-                for e_id in ev.prev_event_ids()
-            ],
-        )
-
-        self._update_backward_extremeties(txn, events)
-
-    def _update_backward_extremeties(self, txn, events):
-        """Updates the event_backward_extremities tables based on the new/updated
-        events being persisted.
-
-        This is called for new events *and* for events that were outliers, but
-        are now being persisted as non-outliers.
-
-        Forward extremities are handled when we first start persisting the events.
-        """
-        events_by_room = {}
-        for ev in events:
-            events_by_room.setdefault(ev.room_id, []).append(ev)
-
-        query = (
-            "INSERT INTO event_backward_extremities (event_id, room_id)"
-            " SELECT ?, ? WHERE NOT EXISTS ("
-            " SELECT 1 FROM event_backward_extremities"
-            " WHERE event_id = ? AND room_id = ?"
-            " )"
-            " AND NOT EXISTS ("
-            " SELECT 1 FROM events WHERE event_id = ? AND room_id = ? "
-            " AND outlier = ?"
-            " )"
-        )
-
-        txn.executemany(
-            query,
-            [
-                (e_id, ev.room_id, e_id, ev.room_id, e_id, ev.room_id, False)
-                for ev in events
-                for e_id in ev.prev_event_ids()
-                if not ev.internal_metadata.is_outlier()
-            ],
-        )
-
-        query = (
-            "DELETE FROM event_backward_extremities"
-            " WHERE event_id = ? AND room_id = ?"
-        )
-        txn.executemany(
-            query,
-            [
-                (ev.event_id, ev.room_id)
-                for ev in events
-                if not ev.internal_metadata.is_outlier()
-            ],
-        )
-
     def _delete_old_forward_extrem_cache(self):
         def _delete_old_forward_extrem_cache_txn(txn):
             # Delete entries older than a month, while making sure we don't delete
@@ -591,13 +659,13 @@ class EventFederationStore(EventFederationWorkerStore):
 
         return run_as_background_process(
             "delete_old_forward_extrem_cache",
-            self.runInteraction,
+            self.db.runInteraction,
             "_delete_old_forward_extrem_cache",
             _delete_old_forward_extrem_cache_txn,
         )
 
     def clean_room_for_join(self, room_id):
-        return self.runInteraction(
+        return self.db.runInteraction(
             "clean_room_for_join", self._clean_room_for_join_txn, room_id
         )
 
@@ -641,17 +709,17 @@ class EventFederationStore(EventFederationWorkerStore):
                 "max_stream_id_exclusive": min_stream_id,
             }
 
-            self._background_update_progress_txn(
+            self.db.updates._background_update_progress_txn(
                 txn, self.EVENT_AUTH_STATE_ONLY, new_progress
             )
 
             return min_stream_id >= target_min_stream_id
 
-        result = yield self.runInteraction(
+        result = yield self.db.runInteraction(
             self.EVENT_AUTH_STATE_ONLY, delete_event_auth
         )
 
         if not result:
-            yield self._end_background_update(self.EVENT_AUTH_STATE_ONLY)
+            yield self.db.updates._end_background_update(self.EVENT_AUTH_STATE_ONLY)
 
         return batch_size
diff --git a/synapse/storage/event_push_actions.py b/synapse/storage/data_stores/main/event_push_actions.py
index b01a12528f..8ad7a306f8 100644
--- a/synapse/storage/event_push_actions.py
+++ b/synapse/storage/data_stores/main/event_push_actions.py
@@ -24,6 +24,7 @@ from twisted.internet import defer
 
 from synapse.metrics.background_process_metrics import run_as_background_process
 from synapse.storage._base import LoggingTransaction, SQLBaseStore
+from synapse.storage.database import Database
 from synapse.util.caches.descriptors import cachedInlineCallbacks
 
 logger = logging.getLogger(__name__)
@@ -68,8 +69,8 @@ def _deserialize_action(actions, is_highlight):
 
 
 class EventPushActionsWorkerStore(SQLBaseStore):
-    def __init__(self, db_conn, hs):
-        super(EventPushActionsWorkerStore, self).__init__(db_conn, hs)
+    def __init__(self, database: Database, db_conn, hs):
+        super(EventPushActionsWorkerStore, self).__init__(database, db_conn, hs)
 
         # These get correctly set by _find_stream_orderings_for_times_txn
         self.stream_ordering_month_ago = None
@@ -93,7 +94,7 @@ class EventPushActionsWorkerStore(SQLBaseStore):
     def get_unread_event_push_actions_by_room_for_user(
         self, room_id, user_id, last_read_event_id
     ):
-        ret = yield self.runInteraction(
+        ret = yield self.db.runInteraction(
             "get_unread_event_push_actions_by_room",
             self._get_unread_counts_by_receipt_txn,
             room_id,
@@ -177,7 +178,7 @@ class EventPushActionsWorkerStore(SQLBaseStore):
             txn.execute(sql, (min_stream_ordering, max_stream_ordering))
             return [r[0] for r in txn]
 
-        ret = yield self.runInteraction("get_push_action_users_in_range", f)
+        ret = yield self.db.runInteraction("get_push_action_users_in_range", f)
         return ret
 
     @defer.inlineCallbacks
@@ -230,7 +231,7 @@ class EventPushActionsWorkerStore(SQLBaseStore):
             txn.execute(sql, args)
             return txn.fetchall()
 
-        after_read_receipt = yield self.runInteraction(
+        after_read_receipt = yield self.db.runInteraction(
             "get_unread_push_actions_for_user_in_range_http_arr", get_after_receipt
         )
 
@@ -259,7 +260,7 @@ class EventPushActionsWorkerStore(SQLBaseStore):
             txn.execute(sql, args)
             return txn.fetchall()
 
-        no_read_receipt = yield self.runInteraction(
+        no_read_receipt = yield self.db.runInteraction(
             "get_unread_push_actions_for_user_in_range_http_nrr", get_no_receipt
         )
 
@@ -332,7 +333,7 @@ class EventPushActionsWorkerStore(SQLBaseStore):
             txn.execute(sql, args)
             return txn.fetchall()
 
-        after_read_receipt = yield self.runInteraction(
+        after_read_receipt = yield self.db.runInteraction(
             "get_unread_push_actions_for_user_in_range_email_arr", get_after_receipt
         )
 
@@ -361,7 +362,7 @@ class EventPushActionsWorkerStore(SQLBaseStore):
             txn.execute(sql, args)
             return txn.fetchall()
 
-        no_read_receipt = yield self.runInteraction(
+        no_read_receipt = yield self.db.runInteraction(
             "get_unread_push_actions_for_user_in_range_email_nrr", get_no_receipt
         )
 
@@ -411,7 +412,7 @@ class EventPushActionsWorkerStore(SQLBaseStore):
             txn.execute(sql, (user_id, min_stream_ordering))
             return bool(txn.fetchone())
 
-        return self.runInteraction(
+        return self.db.runInteraction(
             "get_if_maybe_push_in_range_for_user",
             _get_if_maybe_push_in_range_for_user_txn,
         )
@@ -446,7 +447,7 @@ class EventPushActionsWorkerStore(SQLBaseStore):
             )
 
         def _add_push_actions_to_staging_txn(txn):
-            # We don't use _simple_insert_many here to avoid the overhead
+            # We don't use simple_insert_many here to avoid the overhead
             # of generating lists of dicts.
 
             sql = """
@@ -463,7 +464,7 @@ class EventPushActionsWorkerStore(SQLBaseStore):
                 ),
             )
 
-        return self.runInteraction(
+        return self.db.runInteraction(
             "add_push_actions_to_staging", _add_push_actions_to_staging_txn
         )
 
@@ -477,7 +478,7 @@ class EventPushActionsWorkerStore(SQLBaseStore):
         """
 
         try:
-            res = yield self._simple_delete(
+            res = yield self.db.simple_delete(
                 table="event_push_actions_staging",
                 keyvalues={"event_id": event_id},
                 desc="remove_push_actions_from_staging",
@@ -494,7 +495,7 @@ class EventPushActionsWorkerStore(SQLBaseStore):
     def _find_stream_orderings_for_times(self):
         return run_as_background_process(
             "event_push_action_stream_orderings",
-            self.runInteraction,
+            self.db.runInteraction,
             "_find_stream_orderings_for_times",
             self._find_stream_orderings_for_times_txn,
         )
@@ -530,7 +531,7 @@ class EventPushActionsWorkerStore(SQLBaseStore):
             Deferred[int]: stream ordering of the first event received on/after
                 the timestamp
         """
-        return self.runInteraction(
+        return self.db.runInteraction(
             "_find_first_stream_ordering_after_ts_txn",
             self._find_first_stream_ordering_after_ts_txn,
             ts,
@@ -612,21 +613,38 @@ class EventPushActionsWorkerStore(SQLBaseStore):
 
         return range_end
 
+    @defer.inlineCallbacks
+    def get_time_of_last_push_action_before(self, stream_ordering):
+        def f(txn):
+            sql = (
+                "SELECT e.received_ts"
+                " FROM event_push_actions AS ep"
+                " JOIN events e ON ep.room_id = e.room_id AND ep.event_id = e.event_id"
+                " WHERE ep.stream_ordering > ?"
+                " ORDER BY ep.stream_ordering ASC"
+                " LIMIT 1"
+            )
+            txn.execute(sql, (stream_ordering,))
+            return txn.fetchone()
+
+        result = yield self.db.runInteraction("get_time_of_last_push_action_before", f)
+        return result[0] if result else None
+
 
 class EventPushActionsStore(EventPushActionsWorkerStore):
     EPA_HIGHLIGHT_INDEX = "epa_highlight_index"
 
-    def __init__(self, db_conn, hs):
-        super(EventPushActionsStore, self).__init__(db_conn, hs)
+    def __init__(self, database: Database, db_conn, hs):
+        super(EventPushActionsStore, self).__init__(database, db_conn, hs)
 
-        self.register_background_index_update(
+        self.db.updates.register_background_index_update(
             self.EPA_HIGHLIGHT_INDEX,
             index_name="event_push_actions_u_highlight",
             table="event_push_actions",
             columns=["user_id", "stream_ordering"],
         )
 
-        self.register_background_index_update(
+        self.db.updates.register_background_index_update(
             "event_push_actions_highlights_index",
             index_name="event_push_actions_highlights_index",
             table="event_push_actions",
@@ -639,69 +657,6 @@ class EventPushActionsStore(EventPushActionsWorkerStore):
             self._start_rotate_notifs, 30 * 60 * 1000
         )
 
-    def _set_push_actions_for_event_and_users_txn(
-        self, txn, events_and_contexts, all_events_and_contexts
-    ):
-        """Handles moving push actions from staging table to main
-        event_push_actions table for all events in `events_and_contexts`.
-
-        Also ensures that all events in `all_events_and_contexts` are removed
-        from the push action staging area.
-
-        Args:
-            events_and_contexts (list[(EventBase, EventContext)]): events
-                we are persisting
-            all_events_and_contexts (list[(EventBase, EventContext)]): all
-                events that we were going to persist. This includes events
-                we've already persisted, etc, that wouldn't appear in
-                events_and_context.
-        """
-
-        sql = """
-            INSERT INTO event_push_actions (
-                room_id, event_id, user_id, actions, stream_ordering,
-                topological_ordering, notif, highlight
-            )
-            SELECT ?, event_id, user_id, actions, ?, ?, notif, highlight
-            FROM event_push_actions_staging
-            WHERE event_id = ?
-        """
-
-        if events_and_contexts:
-            txn.executemany(
-                sql,
-                (
-                    (
-                        event.room_id,
-                        event.internal_metadata.stream_ordering,
-                        event.depth,
-                        event.event_id,
-                    )
-                    for event, _ in events_and_contexts
-                ),
-            )
-
-        for event, _ in events_and_contexts:
-            user_ids = self._simple_select_onecol_txn(
-                txn,
-                table="event_push_actions_staging",
-                keyvalues={"event_id": event.event_id},
-                retcol="user_id",
-            )
-
-            for uid in user_ids:
-                txn.call_after(
-                    self.get_unread_event_push_actions_by_room_for_user.invalidate_many,
-                    (event.room_id, uid),
-                )
-
-        # Now we delete the staging area for *all* events that were being
-        # persisted.
-        txn.executemany(
-            "DELETE FROM event_push_actions_staging WHERE event_id = ?",
-            ((event.event_id,) for event, _ in all_events_and_contexts),
-        )
-
     @defer.inlineCallbacks
     def get_push_actions_for_user(
         self, user_id, before=None, limit=50, only_highlight=False
@@ -732,49 +687,23 @@ class EventPushActionsStore(EventPushActionsWorkerStore):
                 " LIMIT ?" % (before_clause,)
             )
             txn.execute(sql, args)
-            return self.cursor_to_dict(txn)
+            return self.db.cursor_to_dict(txn)
 
-        push_actions = yield self.runInteraction("get_push_actions_for_user", f)
+        push_actions = yield self.db.runInteraction("get_push_actions_for_user", f)
         for pa in push_actions:
             pa["actions"] = _deserialize_action(pa["actions"], pa["highlight"])
         return push_actions
 
     @defer.inlineCallbacks
-    def get_time_of_last_push_action_before(self, stream_ordering):
-        def f(txn):
-            sql = (
-                "SELECT e.received_ts"
-                " FROM event_push_actions AS ep"
-                " JOIN events e ON ep.room_id = e.room_id AND ep.event_id = e.event_id"
-                " WHERE ep.stream_ordering > ?"
-                " ORDER BY ep.stream_ordering ASC"
-                " LIMIT 1"
-            )
-            txn.execute(sql, (stream_ordering,))
-            return txn.fetchone()
-
-        result = yield self.runInteraction("get_time_of_last_push_action_before", f)
-        return result[0] if result else None
-
-    @defer.inlineCallbacks
     def get_latest_push_action_stream_ordering(self):
         def f(txn):
             txn.execute("SELECT MAX(stream_ordering) FROM event_push_actions")
             return txn.fetchone()
 
-        result = yield self.runInteraction("get_latest_push_action_stream_ordering", f)
-        return result[0] or 0
-
-    def _remove_push_actions_for_event_id_txn(self, txn, room_id, event_id):
-        # Sad that we have to blow away the cache for the whole room here
-        txn.call_after(
-            self.get_unread_event_push_actions_by_room_for_user.invalidate_many,
-            (room_id,),
-        )
-        txn.execute(
-            "DELETE FROM event_push_actions WHERE room_id = ? AND event_id = ?",
-            (room_id, event_id),
+        result = yield self.db.runInteraction(
+            "get_latest_push_action_stream_ordering", f
         )
+        return result[0] or 0
 
     def _remove_old_push_actions_before_txn(
         self, txn, room_id, user_id, stream_ordering
@@ -835,7 +764,7 @@ class EventPushActionsStore(EventPushActionsWorkerStore):
             while True:
                 logger.info("Rotating notifications")
 
-                caught_up = yield self.runInteraction(
+                caught_up = yield self.db.runInteraction(
                     "_rotate_notifs", self._rotate_notifs_txn
                 )
                 if caught_up:
@@ -849,7 +778,7 @@ class EventPushActionsStore(EventPushActionsWorkerStore):
         the archiving process has caught up or not.
         """
 
-        old_rotate_stream_ordering = self._simple_select_one_onecol_txn(
+        old_rotate_stream_ordering = self.db.simple_select_one_onecol_txn(
             txn,
             table="event_push_summary_stream_ordering",
             keyvalues={},
@@ -868,7 +797,7 @@ class EventPushActionsStore(EventPushActionsWorkerStore):
         )
         stream_row = txn.fetchone()
         if stream_row:
-            offset_stream_ordering, = stream_row
+            (offset_stream_ordering,) = stream_row
             rotate_to_stream_ordering = min(
                 self.stream_ordering_day_ago, offset_stream_ordering
             )
@@ -885,7 +814,7 @@ class EventPushActionsStore(EventPushActionsWorkerStore):
         return caught_up
 
     def _rotate_notifs_before_txn(self, txn, rotate_to_stream_ordering):
-        old_rotate_stream_ordering = self._simple_select_one_onecol_txn(
+        old_rotate_stream_ordering = self.db.simple_select_one_onecol_txn(
             txn,
             table="event_push_summary_stream_ordering",
             keyvalues={},
@@ -917,7 +846,7 @@ class EventPushActionsStore(EventPushActionsWorkerStore):
         # If the `old.user_id` above is NULL then we know there isn't already an
         # entry in the table, so we simply insert it. Otherwise we update the
         # existing table.
-        self._simple_insert_many_txn(
+        self.db.simple_insert_many_txn(
             txn,
             table="event_push_summary",
             values=[
diff --git a/synapse/storage/data_stores/main/events.py b/synapse/storage/data_stores/main/events.py
new file mode 100644
index 0000000000..a6572571b4
--- /dev/null
+++ b/synapse/storage/data_stores/main/events.py
@@ -0,0 +1,1620 @@
+# -*- coding: utf-8 -*-
+# Copyright 2014-2016 OpenMarket Ltd
+# Copyright 2018-2019 New Vector Ltd
+# Copyright 2019 The Matrix.org Foundation C.I.C.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import itertools
+import logging
+from collections import OrderedDict, namedtuple
+from functools import wraps
+from typing import TYPE_CHECKING, Dict, Iterable, List, Tuple
+
+from six import integer_types, iteritems, text_type
+from six.moves import range
+
+import attr
+from canonicaljson import json
+from prometheus_client import Counter
+
+from twisted.internet import defer
+
+import synapse.metrics
+from synapse.api.constants import (
+    EventContentFields,
+    EventTypes,
+    Membership,
+    RelationTypes,
+)
+from synapse.api.room_versions import RoomVersions
+from synapse.crypto.event_signing import compute_event_reference_hash
+from synapse.events import EventBase  # noqa: F401
+from synapse.events.snapshot import EventContext  # noqa: F401
+from synapse.logging.utils import log_function
+from synapse.storage._base import make_in_list_sql_clause
+from synapse.storage.data_stores.main.search import SearchEntry
+from synapse.storage.database import Database, LoggingTransaction
+from synapse.storage.util.id_generators import StreamIdGenerator
+from synapse.types import StateMap, get_domain_from_id
+from synapse.util.frozenutils import frozendict_json_encoder
+from synapse.util.iterutils import batch_iter
+
+if TYPE_CHECKING:
+    from synapse.storage.data_stores.main import DataStore
+    from synapse.server import HomeServer
+
+
+logger = logging.getLogger(__name__)
+
+persist_event_counter = Counter("synapse_storage_events_persisted_events", "")
+event_counter = Counter(
+    "synapse_storage_events_persisted_events_sep",
+    "",
+    ["type", "origin_type", "origin_entity"],
+)
+
+
+def encode_json(json_object):
+    """
+    Encode a Python object as JSON and return it in a Unicode string.
+    """
+    out = frozendict_json_encoder.encode(json_object)
+    if isinstance(out, bytes):
+        out = out.decode("utf8")
+    return out
+
+
+_EventCacheEntry = namedtuple("_EventCacheEntry", ("event", "redacted_event"))
+
+
+def _retry_on_integrity_error(func):
+    """Wraps a database function so that it gets retried on IntegrityError,
+    with `delete_existing=True` passed in.
+
+    Args:
+        func: function that returns a Deferred and accepts a `delete_existing` arg
+    """
+
+    @wraps(func)
+    @defer.inlineCallbacks
+    def f(self, *args, **kwargs):
+        try:
+            res = yield func(self, *args, delete_existing=False, **kwargs)
+        except self.database_engine.module.IntegrityError:
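+            # Retrying with delete_existing=True lets the wrapped function
+            # purge any partially-inserted rows for these events before
+            # re-inserting them.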
+            logger.exception("IntegrityError, retrying.")
+            res = yield func(self, *args, delete_existing=True, **kwargs)
+        return res
+
+    return f
+
+
+@attr.s(slots=True)
+class DeltaState:
+    """Deltas to use to update the `current_state_events` table.
+
+    Attributes:
+        to_delete: List of type/state_keys to delete from current state
+        to_insert: Map of state to upsert into current state
+        no_longer_in_room: The server is no longer in the room, so the room
+            should e.g. be removed from the `current_state_events` table.
+    """
+
+    to_delete = attr.ib(type=List[Tuple[str, str]])
+    to_insert = attr.ib(type=StateMap[str])
+    no_longer_in_room = attr.ib(type=bool, default=False)
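+
+    # Illustrative (hypothetical) example: a state change that removes a room
+    # alias event and sets a new membership event might look like
+    #   DeltaState(
+    #       to_delete=[("m.room.canonical_alias", "")],
+    #       to_insert={("m.room.member", "@alice:example.com"): "$new_event_id"},
+    #   )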
+
+
+class PersistEventsStore:
+    """Contains all the functions for writing events to the database.
+
+    Should only be instantiated on one process (when using a worker mode setup).
+
+    Note: This is not part of the `DataStore` mixin.
+    """
+
+    def __init__(self, hs: "HomeServer", db: Database, main_data_store: "DataStore"):
+        self.hs = hs
+        self.db = db
+        self.store = main_data_store
+        self.database_engine = db.engine
+        self._clock = hs.get_clock()
+
+        self._ephemeral_messages_enabled = hs.config.enable_ephemeral_messages
+        self.is_mine_id = hs.is_mine_id
+
+        # Ideally we'd move these ID gens here, unfortunately some other ID
+        # generators are chained off them so doing so is a bit of a PITA.
+        self._backfill_id_gen = self.store._backfill_id_gen  # type: StreamIdGenerator
+        self._stream_id_gen = self.store._stream_id_gen  # type: StreamIdGenerator
+
+        # This should only exist on instances that are configured to write
+        assert (
+            hs.config.worker.writers.events == hs.get_instance_name()
+        ), "Can only instantiate EventsStore on master"
+
+    @_retry_on_integrity_error
+    @defer.inlineCallbacks
+    def _persist_events_and_state_updates(
+        self,
+        events_and_contexts: List[Tuple[EventBase, EventContext]],
+        current_state_for_room: Dict[str, StateMap[str]],
+        state_delta_for_room: Dict[str, DeltaState],
+        new_forward_extremeties: Dict[str, List[str]],
+        backfilled: bool = False,
+        delete_existing: bool = False,
+    ):
+        """Persist a set of events alongside updates to the current state and
+        forward extremities tables.
+
+        Args:
+            events_and_contexts:
+            current_state_for_room: Map from room_id to the current state of
+                the room based on forward extremities
+            state_delta_for_room: Map from room_id to the delta to apply to
+                room state
+            new_forward_extremities: Map from room_id to list of event IDs
+                that are the new forward extremities of the room.
+            backfilled: True if the events were backfilled
+            delete_existing: True to purge existing table rows for the events
+                from the database before re-inserting them (used when
+                retrying after an IntegrityError)
+
+        Returns:
+            Deferred: resolves when the events have been persisted
+        """
+
+        # We want to calculate the stream orderings as late as possible, as
+        # we only notify after all events with a lesser stream ordering have
+        # been persisted. I.e. if we spend 10s inside the with block then
+        # that will delay all subsequent events from being notified about.
+        # Hence why we do it down here rather than wrapping the entire
+        # function.
+        #
+        # Its safe to do this after calculating the state deltas etc as we
+        # only need to protect the *persistence* of the events. This is to
+        # ensure that queries of the form "fetch events since X" don't
+        # return events and stream positions after events that are still in
+        # flight, as otherwise subsequent requests "fetch event since Y"
+        # will not return those events.
+        #
+        # Note: Multiple instances of this function cannot be in flight at
+        # the same time for the same room.
+        if backfilled:
+            stream_ordering_manager = self._backfill_id_gen.get_next_mult(
+                len(events_and_contexts)
+            )
+        else:
+            stream_ordering_manager = self._stream_id_gen.get_next_mult(
+                len(events_and_contexts)
+            )
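+        # (Backfilled events get negative stream orderings from
+        # _backfill_id_gen, so they sort before all live events; see the
+        # event_persisted_position handling below.)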
+
+        with stream_ordering_manager as stream_orderings:
+            for (event, context), stream in zip(events_and_contexts, stream_orderings):
+                event.internal_metadata.stream_ordering = stream
+
+            yield self.db.runInteraction(
+                "persist_events",
+                self._persist_events_txn,
+                events_and_contexts=events_and_contexts,
+                backfilled=backfilled,
+                delete_existing=delete_existing,
+                state_delta_for_room=state_delta_for_room,
+                new_forward_extremeties=new_forward_extremeties,
+            )
+            persist_event_counter.inc(len(events_and_contexts))
+
+            if not backfilled:
+                # backfilled events have negative stream orderings, so we don't
+                # want to set the event_persisted_position to that.
+                synapse.metrics.event_persisted_position.set(
+                    events_and_contexts[-1][0].internal_metadata.stream_ordering
+                )
+
+            for event, context in events_and_contexts:
+                if context.app_service:
+                    origin_type = "local"
+                    origin_entity = context.app_service.id
+                elif self.hs.is_mine_id(event.sender):
+                    origin_type = "local"
+                    origin_entity = "*client*"
+                else:
+                    origin_type = "remote"
+                    origin_entity = get_domain_from_id(event.sender)
+
+                event_counter.labels(event.type, origin_type, origin_entity).inc()
+
+            for room_id, new_state in iteritems(current_state_for_room):
+                self.store.get_current_state_ids.prefill((room_id,), new_state)
+
+            for room_id, latest_event_ids in iteritems(new_forward_extremeties):
+                self.store.get_latest_event_ids_in_room.prefill(
+                    (room_id,), list(latest_event_ids)
+                )
+
+    @defer.inlineCallbacks
+    def _get_events_which_are_prevs(self, event_ids):
+        """Filter the supplied list of event_ids to get those which are prev_events of
+        existing (non-outlier/rejected) events.
+
+        Args:
+            event_ids (Iterable[str]): event ids to filter
+
+        Returns:
+            Deferred[List[str]]: filtered event ids
+        """
+        results = []
+
+        def _get_events_which_are_prevs_txn(txn, batch):
+            sql = """
+            SELECT prev_event_id, internal_metadata
+            FROM event_edges
+                INNER JOIN events USING (event_id)
+                LEFT JOIN rejections USING (event_id)
+                LEFT JOIN event_json USING (event_id)
+            WHERE
+                NOT events.outlier
+                AND rejections.event_id IS NULL
+                AND
+            """
+
+            clause, args = make_in_list_sql_clause(
+                self.database_engine, "prev_event_id", batch
+            )
+
+            txn.execute(sql + clause, args)
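+            # internal_metadata belongs to the *referencing* event, so a prev
+            # event is only kept if at least one non-soft-failed (and
+            # non-rejected, non-outlier) event points at it.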
+            results.extend(r[0] for r in txn if not json.loads(r[1]).get("soft_failed"))
+
+        for chunk in batch_iter(event_ids, 100):
+            yield self.db.runInteraction(
+                "_get_events_which_are_prevs", _get_events_which_are_prevs_txn, chunk
+            )
+
+        return results
+
+    @defer.inlineCallbacks
+    def _get_prevs_before_rejected(self, event_ids):
+        """Get soft-failed ancestors to remove from the extremities.
+
+        Given a set of events, find all those that have been soft-failed or
+        rejected. Returns those soft failed/rejected events and their prev
+        events (whether soft-failed/rejected or not), and recurses up the
+        prev-event graph until it finds no more soft-failed/rejected events.
+
+        This is used to find extremities that are ancestors of new events, but
+        are separated by soft failed events.
+
+        Args:
+            event_ids (Iterable[str]): Events to find prev events for. Note
+                that these must have already been persisted.
+
+        Returns:
+            Deferred[set[str]]
+        """
+
+        # The set of event_ids to return. This includes all soft-failed events
+        # and their prev events.
+        existing_prevs = set()
+
+        def _get_prevs_before_rejected_txn(txn, batch):
+            to_recursively_check = batch
+
+            while to_recursively_check:
+                sql = """
+                SELECT
+                    event_id, prev_event_id, internal_metadata,
+                    rejections.event_id IS NOT NULL
+                FROM event_edges
+                    INNER JOIN events USING (event_id)
+                    LEFT JOIN rejections USING (event_id)
+                    LEFT JOIN event_json USING (event_id)
+                WHERE
+                    NOT events.outlier
+                    AND
+                """
+
+                clause, args = make_in_list_sql_clause(
+                    self.database_engine, "event_id", to_recursively_check
+                )
+
+                txn.execute(sql + clause, args)
+                to_recursively_check = []
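+                # Only the prevs of events that are themselves soft-failed or
+                # rejected get queued for another pass; the walk stops once a
+                # pass finds none.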
+
+                for event_id, prev_event_id, metadata, rejected in txn:
+                    if prev_event_id in existing_prevs:
+                        continue
+
+                    soft_failed = json.loads(metadata).get("soft_failed")
+                    if soft_failed or rejected:
+                        to_recursively_check.append(prev_event_id)
+                        existing_prevs.add(prev_event_id)
+
+        for chunk in batch_iter(event_ids, 100):
+            yield self.db.runInteraction(
+                "_get_prevs_before_rejected", _get_prevs_before_rejected_txn, chunk
+            )
+
+        return existing_prevs
+
+    @log_function
+    def _persist_events_txn(
+        self,
+        txn: LoggingTransaction,
+        events_and_contexts: List[Tuple[EventBase, EventContext]],
+        backfilled: bool,
+        delete_existing: bool = False,
+        state_delta_for_room: Dict[str, DeltaState] = {},
+        new_forward_extremeties: Dict[str, List[str]] = {},
+    ):
+        """Insert some number of room events into the necessary database tables.
+
+        Rejected events are only inserted into the events table, the event_json table,
+        and the rejections table. Things reading from those tables will need to check
+        whether the event was rejected.
+
+        Args:
+            txn
+            events_and_contexts: events to persist
+            backfilled: True if the events were backfilled
+            delete_existing: True to purge existing table rows for the events
+                from the database. This is useful when retrying due to
+                IntegrityError.
+            state_delta_for_room: The current-state delta for each room.
+            new_forward_extremeties: The new forward extremities for each room.
+                For each room, a list of the event ids which are the forward
+                extremities.
+
+        """
+        all_events_and_contexts = events_and_contexts
+
+        min_stream_order = events_and_contexts[0][0].internal_metadata.stream_ordering
+        max_stream_order = events_and_contexts[-1][0].internal_metadata.stream_ordering
+
+        self._update_forward_extremities_txn(
+            txn,
+            new_forward_extremities=new_forward_extremeties,
+            max_stream_order=max_stream_order,
+        )
+
+        # Ensure that we don't have the same event twice.
+        events_and_contexts = self._filter_events_and_contexts_for_duplicates(
+            events_and_contexts
+        )
+
+        self._update_room_depths_txn(
+            txn, events_and_contexts=events_and_contexts, backfilled=backfilled
+        )
+
+        # _update_outliers_txn filters out any events which have already been
+        # persisted, and returns the filtered list.
+        events_and_contexts = self._update_outliers_txn(
+            txn, events_and_contexts=events_and_contexts
+        )
+
+        # From this point onwards the events are only events that we haven't
+        # seen before.
+
+        if delete_existing:
+            # For paranoia reasons, we go and delete all the existing entries
+            # for these events so we can reinsert them.
+            # This gets around any problems with some tables already having
+            # entries.
+            self._delete_existing_rows_txn(txn, events_and_contexts=events_and_contexts)
+
+        self._store_event_txn(txn, events_and_contexts=events_and_contexts)
+
+        # Insert into event_to_state_groups.
+        self._store_event_state_mappings_txn(txn, events_and_contexts)
+
+        # We want to store event_auth mappings for rejected events, as they're
+        # used in state res v2.
+        # This is only necessary if the rejected event appears in an accepted
+        # event's auth chain, but it's easier for now just to store them (and
+        # it doesn't take much storage compared to storing the entire event
+        # anyway).
+        self.db.simple_insert_many_txn(
+            txn,
+            table="event_auth",
+            values=[
+                {
+                    "event_id": event.event_id,
+                    "room_id": event.room_id,
+                    "auth_id": auth_id,
+                }
+                for event, _ in events_and_contexts
+                for auth_id in event.auth_event_ids()
+                if event.is_state()
+            ],
+        )
+
+        # _store_rejected_events_txn filters out any events which were
+        # rejected, and returns the filtered list.
+        events_and_contexts = self._store_rejected_events_txn(
+            txn, events_and_contexts=events_and_contexts
+        )
+
+        # From this point onwards the events are only ones that weren't
+        # rejected.
+
+        self._update_metadata_tables_txn(
+            txn,
+            events_and_contexts=events_and_contexts,
+            all_events_and_contexts=all_events_and_contexts,
+            backfilled=backfilled,
+        )
+
+        # We call this last as it assumes we've inserted the events into
+        # room_memberships, where applicable.
+        self._update_current_state_txn(txn, state_delta_for_room, min_stream_order)
+
+    def _update_current_state_txn(
+        self,
+        txn: LoggingTransaction,
+        state_delta_by_room: Dict[str, DeltaState],
+        stream_id: int,
+    ):
+        for room_id, delta_state in iteritems(state_delta_by_room):
+            to_delete = delta_state.to_delete
+            to_insert = delta_state.to_insert
+
+            if delta_state.no_longer_in_room:
+                # The server is no longer in the room, so we delete the room
+                # from current_state_events, taking care to first update the
+                # rooms.room_version column (which gets populated in a
+                # background task).
+                self._upsert_room_version_txn(txn, room_id)
+
+                # Before deleting we populate the current_state_delta_stream
+                # so that async background tasks get told what happened.
+                sql = """
+                    INSERT INTO current_state_delta_stream
+                        (stream_id, room_id, type, state_key, event_id, prev_event_id)
+                    SELECT ?, room_id, type, state_key, null, event_id
+                        FROM current_state_events
+                        WHERE room_id = ?
+                """
+                txn.execute(sql, (stream_id, room_id))
+
+                self.db.simple_delete_txn(
+                    txn, table="current_state_events", keyvalues={"room_id": room_id},
+                )
+            else:
+                # We're still in the room, so we update the current state as normal.
+
+                # First we add entries to the current_state_delta_stream. We
+                # do this before updating the current_state_events table so
+                # that we can use it to calculate the `prev_event_id`. (This
+                # allows us to not have to pull out the existing state
+                # unnecessarily).
+                #
+                # The stream_id for the update is chosen to be the minimum of the stream_ids
+                # for the batch of the events that we are persisting; that means we do not
+                # end up in a situation where workers see events before the
+                # current_state_delta updates.
+                #
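+                # Keys being deleted have no entry in to_insert, so
+                # to_insert.get() yields None and the new event_id column is
+                # recorded as NULL for them.
+                #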
+                sql = """
+                    INSERT INTO current_state_delta_stream
+                    (stream_id, room_id, type, state_key, event_id, prev_event_id)
+                    SELECT ?, ?, ?, ?, ?, (
+                        SELECT event_id FROM current_state_events
+                        WHERE room_id = ? AND type = ? AND state_key = ?
+                    )
+                """
+                txn.executemany(
+                    sql,
+                    (
+                        (
+                            stream_id,
+                            room_id,
+                            etype,
+                            state_key,
+                            to_insert.get((etype, state_key)),
+                            room_id,
+                            etype,
+                            state_key,
+                        )
+                        for etype, state_key in itertools.chain(to_delete, to_insert)
+                    ),
+                )
+                # Now we actually update the current_state_events table
+
+                txn.executemany(
+                    "DELETE FROM current_state_events"
+                    " WHERE room_id = ? AND type = ? AND state_key = ?",
+                    (
+                        (room_id, etype, state_key)
+                        for etype, state_key in itertools.chain(to_delete, to_insert)
+                    ),
+                )
+
+                # We include the membership in the current state table, hence we do
+                # a lookup when we insert. This assumes that all events have already
+                # been inserted into room_memberships.
+                txn.executemany(
+                    """INSERT INTO current_state_events
+                        (room_id, type, state_key, event_id, membership)
+                    VALUES (?, ?, ?, ?, (SELECT membership FROM room_memberships WHERE event_id = ?))
+                    """,
+                    [
+                        (room_id, key[0], key[1], ev_id, ev_id)
+                        for key, ev_id in iteritems(to_insert)
+                    ],
+                )
+
+            # We now update `local_current_membership`. We do this regardless
+            # of whether we're still in the room or not to handle the case where
+            # e.g. we just got banned (where we need to record that fact here).
+
+            # Note: Do we really want to delete rows here (that we do not
+            # subsequently reinsert below)? While technically correct it means
+            # we have no record of the fact the user *was* a member of the
+            # room but got, say, state reset out of it.
+            if to_delete or to_insert:
+                txn.executemany(
+                    "DELETE FROM local_current_membership"
+                    " WHERE room_id = ? AND user_id = ?",
+                    (
+                        (room_id, state_key)
+                        for etype, state_key in itertools.chain(to_delete, to_insert)
+                        if etype == EventTypes.Member and self.is_mine_id(state_key)
+                    ),
+                )
+
+            if to_insert:
+                txn.executemany(
+                    """INSERT INTO local_current_membership
+                        (room_id, user_id, event_id, membership)
+                    VALUES (?, ?, ?, (SELECT membership FROM room_memberships WHERE event_id = ?))
+                    """,
+                    [
+                        (room_id, key[1], ev_id, ev_id)
+                        for key, ev_id in to_insert.items()
+                        if key[0] == EventTypes.Member and self.is_mine_id(key[1])
+                    ],
+                )
+
+            txn.call_after(
+                self.store._curr_state_delta_stream_cache.entity_has_changed,
+                room_id,
+                stream_id,
+            )
+
+            # Invalidate the various caches
+
+            # Figure out the changes of membership to invalidate the
+            # `get_rooms_for_user` cache.
+            # We find out which membership events we may have deleted
+            # and which we have added, then we invalidate the caches for all
+            # those users.
+            members_changed = {
+                state_key
+                for ev_type, state_key in itertools.chain(to_delete, to_insert)
+                if ev_type == EventTypes.Member
+            }
+
+            for member in members_changed:
+                txn.call_after(
+                    self.store.get_rooms_for_user_with_stream_ordering.invalidate,
+                    (member,),
+                )
+
+            self.store._invalidate_state_caches_and_stream(
+                txn, room_id, members_changed
+            )
+
+    def _upsert_room_version_txn(self, txn: LoggingTransaction, room_id: str):
+        """Update the room version in the database based off current state
+        events.
+
+        This is used when we're about to delete current state and we want to
+        ensure that the `rooms.room_version` column is up to date.
+        """
+
+        sql = """
+            SELECT json FROM event_json
+            INNER JOIN current_state_events USING (room_id, event_id)
+            WHERE room_id = ? AND type = ? AND state_key = ?
+        """
+        txn.execute(sql, (room_id, EventTypes.Create, ""))
+        row = txn.fetchone()
+        if row:
+            event_json = json.loads(row[0])
+            content = event_json.get("content", {})
+            creator = content.get("creator")
+            room_version_id = content.get("room_version", RoomVersions.V1.identifier)
+
+            self.db.simple_upsert_txn(
+                txn,
+                table="rooms",
+                keyvalues={"room_id": room_id},
+                values={"room_version": room_version_id},
+                insertion_values={"is_public": False, "creator": creator},
+            )
+
+    def _update_forward_extremities_txn(
+        self, txn, new_forward_extremities, max_stream_order
+    ):
+        for room_id, new_extrem in iteritems(new_forward_extremities):
+            self.db.simple_delete_txn(
+                txn, table="event_forward_extremities", keyvalues={"room_id": room_id}
+            )
+            txn.call_after(
+                self.store.get_latest_event_ids_in_room.invalidate, (room_id,)
+            )
+
+        self.db.simple_insert_many_txn(
+            txn,
+            table="event_forward_extremities",
+            values=[
+                {"event_id": ev_id, "room_id": room_id}
+                for room_id, new_extrem in iteritems(new_forward_extremities)
+                for ev_id in new_extrem
+            ],
+        )
+        # We now insert into stream_ordering_to_exterm a mapping from room_id,
+        # new stream_ordering to the new forward extremities in the room.
+        # This allows us to later efficiently look up the forward extremities
+        # for a room before a given stream_ordering.
+        self.db.simple_insert_many_txn(
+            txn,
+            table="stream_ordering_to_exterm",
+            values=[
+                {
+                    "room_id": room_id,
+                    "event_id": event_id,
+                    "stream_ordering": max_stream_order,
+                }
+                for room_id, new_extrem in iteritems(new_forward_extremities)
+                for event_id in new_extrem
+            ],
+        )
+
+    @classmethod
+    def _filter_events_and_contexts_for_duplicates(cls, events_and_contexts):
+        """Ensure that we don't have the same event twice.
+
+        Pick the earliest non-outlier if there is one, else the earliest one.
+
+        Args:
+            events_and_contexts (list[(EventBase, EventContext)]):
+        Returns:
+            list[(EventBase, EventContext)]: filtered list
+        """
+        new_events_and_contexts = OrderedDict()
+        for event, context in events_and_contexts:
+            prev_event_context = new_events_and_contexts.get(event.event_id)
+            if prev_event_context:
+                if not event.internal_metadata.is_outlier():
+                    if prev_event_context[0].internal_metadata.is_outlier():
+                        # To ensure correct ordering we pop, as OrderedDict is
+                        # ordered by first insertion.
+                        new_events_and_contexts.pop(event.event_id, None)
+                        new_events_and_contexts[event.event_id] = (event, context)
+            else:
+                new_events_and_contexts[event.event_id] = (event, context)
+        return list(new_events_and_contexts.values())
+
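As a quick illustration of the dedup rule above, here is a self-contained toy (a hypothetical namedtuple standing in for EventBase, not Synapse code): for a duplicated event_id the non-outlier copy wins over an earlier outlier copy, otherwise the earliest entry is kept.

    from collections import OrderedDict, namedtuple

    ToyEvent = namedtuple("ToyEvent", ["event_id", "outlier"])

    def dedup(pairs):
        out = OrderedDict()
        for ev, ctx in pairs:
            prev = out.get(ev.event_id)
            if prev is None:
                out[ev.event_id] = (ev, ctx)
            elif not ev.outlier and prev[0].outlier:
                # Same pop/re-insert dance as the method above.
                out.pop(ev.event_id)
                out[ev.event_id] = (ev, ctx)
        return list(out.values())

    # The outlier copy is superseded by the later non-outlier copy.
    assert dedup(
        [(ToyEvent("$a", True), "ctx1"), (ToyEvent("$a", False), "ctx2")]
    ) == [(ToyEvent("$a", False), "ctx2")]
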
+    def _update_room_depths_txn(self, txn, events_and_contexts, backfilled):
+        """Update min_depth for each room
+
+        Args:
+            txn (twisted.enterprise.adbapi.Connection): db connection
+            events_and_contexts (list[(EventBase, EventContext)]): events
+                we are persisting
+            backfilled (bool): True if the events were backfilled
+        """
+        depth_updates = {}
+        for event, context in events_and_contexts:
+            # Remove any existing cache entries for the event_ids
+            txn.call_after(self.store._invalidate_get_event_cache, event.event_id)
+            if not backfilled:
+                txn.call_after(
+                    self.store._events_stream_cache.entity_has_changed,
+                    event.room_id,
+                    event.internal_metadata.stream_ordering,
+                )
+
+            if not event.internal_metadata.is_outlier() and not context.rejected:
+                depth_updates[event.room_id] = max(
+                    event.depth, depth_updates.get(event.room_id, event.depth)
+                )
+
+        for room_id, depth in iteritems(depth_updates):
+            self._update_min_depth_for_room_txn(txn, room_id, depth)
+
+    def _update_outliers_txn(self, txn, events_and_contexts):
+        """Update any outliers with new event info.
+
+        This turns outliers into ex-outliers (unless the new event was
+        rejected).
+
+        Args:
+            txn (twisted.enterprise.adbapi.Connection): db connection
+            events_and_contexts (list[(EventBase, EventContext)]): events
+                we are persisting
+
+        Returns:
+            list[(EventBase, EventContext)] new list, without events which
+            are already in the events table.
+        """
+        txn.execute(
+            "SELECT event_id, outlier FROM events WHERE event_id in (%s)"
+            % (",".join(["?"] * len(events_and_contexts)),),
+            [event.event_id for event, _ in events_and_contexts],
+        )
+
+        have_persisted = {event_id: outlier for event_id, outlier in txn}
+
+        to_remove = set()
+        for event, context in events_and_contexts:
+            if event.event_id not in have_persisted:
+                continue
+
+            to_remove.add(event)
+
+            if context.rejected:
+                # If the event is rejected then we don't care if the event
+                # was an outlier or not.
+                continue
+
+            outlier_persisted = have_persisted[event.event_id]
+            if not event.internal_metadata.is_outlier() and outlier_persisted:
+                # We received a copy of an event that we had already stored as
+                # an outlier in the database. We now have some state at that
+                # event, so we need to update the state_groups table with that state.
+
+                # insert into event_to_state_groups.
+                try:
+                    self._store_event_state_mappings_txn(txn, ((event, context),))
+                except Exception:
+                    logger.exception("")
+                    raise
+
+                metadata_json = encode_json(event.internal_metadata.get_dict())
+
+                sql = "UPDATE event_json SET internal_metadata = ? WHERE event_id = ?"
+                txn.execute(sql, (metadata_json, event.event_id))
+
+                # Add an entry to the ex_outlier_stream table to replicate the
+                # change in outlier status to our workers.
+                stream_order = event.internal_metadata.stream_ordering
+                state_group_id = context.state_group
+                self.db.simple_insert_txn(
+                    txn,
+                    table="ex_outlier_stream",
+                    values={
+                        "event_stream_ordering": stream_order,
+                        "event_id": event.event_id,
+                        "state_group": state_group_id,
+                    },
+                )
+
+                sql = "UPDATE events SET outlier = ? WHERE event_id = ?"
+                txn.execute(sql, (False, event.event_id))
+
+                # Update the event_backward_extremities table now that this
+                # event isn't an outlier any more.
+                self._update_backward_extremeties(txn, [event])
+
+        return [ec for ec in events_and_contexts if ec[0] not in to_remove]
+
+    @classmethod
+    def _delete_existing_rows_txn(cls, txn, events_and_contexts):
+        if not events_and_contexts:
+            # nothing to do here
+            return
+
+        logger.info("Deleting existing")
+
+        for table in (
+            "events",
+            "event_auth",
+            "event_json",
+            "event_edges",
+            "event_forward_extremities",
+            "event_reference_hashes",
+            "event_search",
+            "event_to_state_groups",
+            "local_invites",
+            "state_events",
+            "rejections",
+            "redactions",
+            "room_memberships",
+        ):
+            txn.executemany(
+                "DELETE FROM %s WHERE event_id = ?" % (table,),
+                [(ev.event_id,) for ev, _ in events_and_contexts],
+            )
+
+        for table in ("event_push_actions",):
+            txn.executemany(
+                "DELETE FROM %s WHERE room_id = ? AND event_id = ?" % (table,),
+                [(ev.room_id, ev.event_id) for ev, _ in events_and_contexts],
+            )
+
+    def _store_event_txn(self, txn, events_and_contexts):
+        """Insert new events into the event and event_json tables
+
+        Args:
+            txn (twisted.enterprise.adbapi.Connection): db connection
+            events_and_contexts (list[(EventBase, EventContext)]): events
+                we are persisting
+        """
+
+        if not events_and_contexts:
+            # nothing to do here
+            return
+
+        def event_dict(event):
+            d = event.get_dict()
+            d.pop("redacted", None)
+            d.pop("redacted_because", None)
+            return d
+
+        self.db.simple_insert_many_txn(
+            txn,
+            table="event_json",
+            values=[
+                {
+                    "event_id": event.event_id,
+                    "room_id": event.room_id,
+                    "internal_metadata": encode_json(
+                        event.internal_metadata.get_dict()
+                    ),
+                    "json": encode_json(event_dict(event)),
+                    "format_version": event.format_version,
+                }
+                for event, _ in events_and_contexts
+            ],
+        )
+
+        self.db.simple_insert_many_txn(
+            txn,
+            table="events",
+            values=[
+                {
+                    "stream_ordering": event.internal_metadata.stream_ordering,
+                    "topological_ordering": event.depth,
+                    "depth": event.depth,
+                    "event_id": event.event_id,
+                    "room_id": event.room_id,
+                    "type": event.type,
+                    "processed": True,
+                    "outlier": event.internal_metadata.is_outlier(),
+                    "origin_server_ts": int(event.origin_server_ts),
+                    "received_ts": self._clock.time_msec(),
+                    "sender": event.sender,
+                    "contains_url": (
+                        "url" in event.content
+                        and isinstance(event.content["url"], text_type)
+                    ),
+                }
+                for event, _ in events_and_contexts
+            ],
+        )
+
+        for event, _ in events_and_contexts:
+            if not event.internal_metadata.is_redacted():
+                # If we're persisting an unredacted event we go and ensure
+                # that we mark any redactions that reference this event as
+                # requiring censoring.
+                self.db.simple_update_txn(
+                    txn,
+                    table="redactions",
+                    keyvalues={"redacts": event.event_id},
+                    updatevalues={"have_censored": False},
+                )
+
+    def _store_rejected_events_txn(self, txn, events_and_contexts):
+        """Add rows to the 'rejections' table for received events which were
+        rejected
+
+        Args:
+            txn (twisted.enterprise.adbapi.Connection): db connection
+            events_and_contexts (list[(EventBase, EventContext)]): events
+                we are persisting
+
+        Returns:
+            list[(EventBase, EventContext)] new list, without the rejected
+                events.
+        """
+        # Remove the rejected events from the list now that we've added them
+        # to the events table and the event_json table.
+        to_remove = set()
+        for event, context in events_and_contexts:
+            if context.rejected:
+                # Insert the event_id into the rejections table
+                self._store_rejections_txn(txn, event.event_id, context.rejected)
+                to_remove.add(event)
+
+        return [ec for ec in events_and_contexts if ec[0] not in to_remove]
+
+    def _update_metadata_tables_txn(
+        self, txn, events_and_contexts, all_events_and_contexts, backfilled
+    ):
+        """Update all the miscellaneous tables for new events
+
+        Args:
+            txn (twisted.enterprise.adbapi.Connection): db connection
+            events_and_contexts (list[(EventBase, EventContext)]): events
+                we are persisting
+            all_events_and_contexts (list[(EventBase, EventContext)]): all
+                events that we were going to persist. This includes events
+                we've already persisted, etc., that wouldn't appear in
+                events_and_contexts.
+            backfilled (bool): True if the events were backfilled
+        """
+
+        # Insert all the push actions into the event_push_actions table.
+        self._set_push_actions_for_event_and_users_txn(
+            txn,
+            events_and_contexts=events_and_contexts,
+            all_events_and_contexts=all_events_and_contexts,
+        )
+
+        if not events_and_contexts:
+            # nothing to do here
+            return
+
+        for event, context in events_and_contexts:
+            if event.type == EventTypes.Redaction and event.redacts is not None:
+                # Remove the entries in the event_push_actions table for the
+                # redacted event.
+                self._remove_push_actions_for_event_id_txn(
+                    txn, event.room_id, event.redacts
+                )
+
+                # Remove from relations table.
+                self._handle_redaction(txn, event.redacts)
+
+        # Update the event_forward_extremities, event_backward_extremities and
+        # event_edges tables.
+        self._handle_mult_prev_events(
+            txn, events=[event for event, _ in events_and_contexts]
+        )
+
+        for event, _ in events_and_contexts:
+            if event.type == EventTypes.Name:
+                # Insert into the event_search table.
+                self._store_room_name_txn(txn, event)
+            elif event.type == EventTypes.Topic:
+                # Insert into the event_search table.
+                self._store_room_topic_txn(txn, event)
+            elif event.type == EventTypes.Message:
+                # Insert into the event_search table.
+                self._store_room_message_txn(txn, event)
+            elif event.type == EventTypes.Redaction and event.redacts is not None:
+                # Insert into the redactions table.
+                self._store_redaction(txn, event)
+            elif event.type == EventTypes.Retention:
+                # Update the room_retention table.
+                self._store_retention_policy_for_room_txn(txn, event)
+
+            self._handle_event_relations(txn, event)
+
+            # Store the labels for this event.
+            labels = event.content.get(EventContentFields.LABELS)
+            if labels:
+                self.insert_labels_for_event_txn(
+                    txn, event.event_id, labels, event.room_id, event.depth
+                )
+
+            if self._ephemeral_messages_enabled:
+                # If there's an expiry timestamp on the event, store it.
+                expiry_ts = event.content.get(EventContentFields.SELF_DESTRUCT_AFTER)
+                if isinstance(expiry_ts, int) and not event.is_state():
+                    self._insert_event_expiry_txn(txn, event.event_id, expiry_ts)
+
+        # Insert into the room_memberships table.
+        self._store_room_members_txn(
+            txn,
+            [
+                event
+                for event, _ in events_and_contexts
+                if event.type == EventTypes.Member
+            ],
+            backfilled=backfilled,
+        )
+
+        # Insert event_reference_hashes table.
+        self._store_event_reference_hashes_txn(
+            txn, [event for event, _ in events_and_contexts]
+        )
+
+        state_events_and_contexts = [
+            ec for ec in events_and_contexts if ec[0].is_state()
+        ]
+
+        state_values = []
+        for event, context in state_events_and_contexts:
+            vals = {
+                "event_id": event.event_id,
+                "room_id": event.room_id,
+                "type": event.type,
+                "state_key": event.state_key,
+            }
+
+            # TODO: How does this work with backfilling?
+            if hasattr(event, "replaces_state"):
+                vals["prev_state"] = event.replaces_state
+
+            state_values.append(vals)
+
+        self.db.simple_insert_many_txn(txn, table="state_events", values=state_values)
+
+        # Prefill the event cache
+        self._add_to_cache(txn, events_and_contexts)
+
+    def _add_to_cache(self, txn, events_and_contexts):
+        to_prefill = []
+
+        rows = []
+        N = 200
+        for i in range(0, len(events_and_contexts), N):
+            ev_map = {e[0].event_id: e[0] for e in events_and_contexts[i : i + N]}
+            if not ev_map:
+                break
+
+            sql = (
+                "SELECT "
+                " e.event_id as event_id, "
+                " r.redacts as redacts,"
+                " rej.event_id as rejects "
+                " FROM events as e"
+                " LEFT JOIN rejections as rej USING (event_id)"
+                " LEFT JOIN redactions as r ON e.event_id = r.redacts"
+                " WHERE "
+            )
+
+            clause, args = make_in_list_sql_clause(
+                self.database_engine, "e.event_id", list(ev_map)
+            )
+
+            txn.execute(sql + clause, args)
+            rows = self.db.cursor_to_dict(txn)
+            for row in rows:
+                event = ev_map[row["event_id"]]
+                if not row["rejects"] and not row["redacts"]:
+                    to_prefill.append(
+                        _EventCacheEntry(event=event, redacted_event=None)
+                    )
+
+        def prefill():
+            for cache_entry in to_prefill:
+                self.store._get_event_cache.prefill(
+                    (cache_entry[0].event_id,), cache_entry
+                )
+
+        txn.call_after(prefill)
+
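For context, make_in_list_sql_clause (used above) generates the WHERE-clause fragment for an IN-style lookup. A rough illustration of the SQLite-style output is below; on PostgreSQL Synapse instead emits a single "= ANY(?)" array parameter, which avoids producing a different statement for every list length. This is a sketch of the idea, not the real implementation.

    def make_in_list_sql_clause_sketch(column, values):
        # SQLite-style expansion: one placeholder per value.
        placeholders = ", ".join("?" for _ in values)
        return "%s IN (%s)" % (column, placeholders), list(values)

    clause, args = make_in_list_sql_clause_sketch("e.event_id", ["$a", "$b"])
    assert clause == "e.event_id IN (?, ?)" and args == ["$a", "$b"]
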
+    def _store_redaction(self, txn, event):
+        # invalidate the cache for the redacted event
+        txn.call_after(self.store._invalidate_get_event_cache, event.redacts)
+
+        self.db.simple_insert_txn(
+            txn,
+            table="redactions",
+            values={
+                "event_id": event.event_id,
+                "redacts": event.redacts,
+                "received_ts": self._clock.time_msec(),
+            },
+        )
+
+    def insert_labels_for_event_txn(
+        self, txn, event_id, labels, room_id, topological_ordering
+    ):
+        """Store the mapping between an event's ID and its labels, with one row per
+        (event_id, label) tuple.
+
+        Args:
+            txn (LoggingTransaction): The transaction to execute.
+            event_id (str): The event's ID.
+            labels (list[str]): A list of text labels.
+            room_id (str): The ID of the room the event was sent to.
+            topological_ordering (int): The position of the event in the room's topology.
+        """
+        return self.db.simple_insert_many_txn(
+            txn=txn,
+            table="event_labels",
+            values=[
+                {
+                    "event_id": event_id,
+                    "label": label,
+                    "room_id": room_id,
+                    "topological_ordering": topological_ordering,
+                }
+                for label in labels
+            ],
+        )
+
+    def _insert_event_expiry_txn(self, txn, event_id, expiry_ts):
+        """Save the expiry timestamp associated with a given event ID.
+
+        Args:
+            txn (LoggingTransaction): The database transaction to use.
+            event_id (str): The event ID the expiry timestamp is associated with.
+            expiry_ts (int): The timestamp at which to expire (delete) the event.
+        """
+        return self.db.simple_insert_txn(
+            txn=txn,
+            table="event_expiry",
+            values={"event_id": event_id, "expiry_ts": expiry_ts},
+        )
+
+    def _store_event_reference_hashes_txn(self, txn, events):
+        """Store a hash for a PDU
+        Args:
+            txn (cursor):
+            events (list): list of Events.
+        """
+
+        vals = []
+        for event in events:
+            ref_alg, ref_hash_bytes = compute_event_reference_hash(event)
+            vals.append(
+                {
+                    "event_id": event.event_id,
+                    "algorithm": ref_alg,
+                    "hash": memoryview(ref_hash_bytes),
+                }
+            )
+
+        self.db.simple_insert_many_txn(txn, table="event_reference_hashes", values=vals)
+
+    def _store_room_members_txn(self, txn, events, backfilled):
+        """Store a room member in the database.
+        """
+        self.db.simple_insert_many_txn(
+            txn,
+            table="room_memberships",
+            values=[
+                {
+                    "event_id": event.event_id,
+                    "user_id": event.state_key,
+                    "sender": event.user_id,
+                    "room_id": event.room_id,
+                    "membership": event.membership,
+                    "display_name": event.content.get("displayname", None),
+                    "avatar_url": event.content.get("avatar_url", None),
+                }
+                for event in events
+            ],
+        )
+
+        for event in events:
+            txn.call_after(
+                self.store._membership_stream_cache.entity_has_changed,
+                event.state_key,
+                event.internal_metadata.stream_ordering,
+            )
+            txn.call_after(
+                self.store.get_invited_rooms_for_local_user.invalidate,
+                (event.state_key,),
+            )
+
+            # We update the local_invites table only if the event is "current",
+            # i.e., it's something that has just happened. If the event is an
+            # outlier it is only current if it's an "out of band membership",
+            # like a remote invite or a rejection of a remote invite.
+            is_new_state = not backfilled and (
+                not event.internal_metadata.is_outlier()
+                or event.internal_metadata.is_out_of_band_membership()
+            )
+            is_mine = self.is_mine_id(event.state_key)
+            if is_new_state and is_mine:
+                if event.membership == Membership.INVITE:
+                    self.db.simple_insert_txn(
+                        txn,
+                        table="local_invites",
+                        values={
+                            "event_id": event.event_id,
+                            "invitee": event.state_key,
+                            "inviter": event.sender,
+                            "room_id": event.room_id,
+                            "stream_id": event.internal_metadata.stream_ordering,
+                        },
+                    )
+                else:
+                    sql = (
+                        "UPDATE local_invites SET stream_id = ?, replaced_by = ? WHERE"
+                        " room_id = ? AND invitee = ? AND locally_rejected is NULL"
+                        " AND replaced_by is NULL"
+                    )
+
+                    txn.execute(
+                        sql,
+                        (
+                            event.internal_metadata.stream_ordering,
+                            event.event_id,
+                            event.room_id,
+                            event.state_key,
+                        ),
+                    )
+
+                # We also update the `local_current_membership` table with the
+                # latest invite info. This will usually get updated by the
+                # `current_state_events` handling, unless it's an outlier.
+                if event.internal_metadata.is_outlier():
+                    # This should only happen for out of band memberships, so
+                    # we add a paranoia check.
+                    assert event.internal_metadata.is_out_of_band_membership()
+
+                    self.db.simple_upsert_txn(
+                        txn,
+                        table="local_current_membership",
+                        keyvalues={
+                            "room_id": event.room_id,
+                            "user_id": event.state_key,
+                        },
+                        values={
+                            "event_id": event.event_id,
+                            "membership": event.membership,
+                        },
+                    )
+
+    def _handle_event_relations(self, txn, event):
+        """Handles inserting relation data during peristence of events
+
+        Args:
+            txn
+            event (EventBase)
+        """
+        relation = event.content.get("m.relates_to")
+        if not relation:
+            # No relations
+            return
+
+        rel_type = relation.get("rel_type")
+        if rel_type not in (
+            RelationTypes.ANNOTATION,
+            RelationTypes.REFERENCE,
+            RelationTypes.REPLACE,
+        ):
+            # Unknown relation type
+            return
+
+        parent_id = relation.get("event_id")
+        if not parent_id:
+            # Invalid relation
+            return
+
+        aggregation_key = relation.get("key")
+
+        self.db.simple_insert_txn(
+            txn,
+            table="event_relations",
+            values={
+                "event_id": event.event_id,
+                "relates_to_id": parent_id,
+                "relation_type": rel_type,
+                "aggregation_key": aggregation_key,
+            },
+        )
+
+        txn.call_after(self.store.get_relations_for_event.invalidate_many, (parent_id,))
+        txn.call_after(
+            self.store.get_aggregation_groups_for_event.invalidate_many, (parent_id,)
+        )
+
+        if rel_type == RelationTypes.REPLACE:
+            txn.call_after(self.store.get_applicable_edit.invalidate, (parent_id,))
+
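For reference, a sketch of the kind of event content the method above acts on: an annotation ("reaction") relating to a parent event. The field names follow the relations format handled here (m.relates_to / rel_type / event_id / key); the event ID is made up.

    example_content = {
        "m.relates_to": {
            "rel_type": "m.annotation",   # RelationTypes.ANNOTATION
            "event_id": "$parent_event_id",
            "key": "👍",
        }
    }
    # _handle_event_relations would store one event_relations row with
    # relates_to_id="$parent_event_id", relation_type="m.annotation" and
    # aggregation_key="👍", then invalidate the relation caches for the parent.
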
+    def _handle_redaction(self, txn, redacted_event_id):
+        """Handles receiving a redaction and checking whether we need to remove
+        any redacted relations from the database.
+
+        Args:
+            txn
+            redacted_event_id (str): The event that was redacted.
+        """
+
+        self.db.simple_delete_txn(
+            txn, table="event_relations", keyvalues={"event_id": redacted_event_id}
+        )
+
+    def _store_room_topic_txn(self, txn, event):
+        if hasattr(event, "content") and "topic" in event.content:
+            self.store_event_search_txn(
+                txn, event, "content.topic", event.content["topic"]
+            )
+
+    def _store_room_name_txn(self, txn, event):
+        if hasattr(event, "content") and "name" in event.content:
+            self.store_event_search_txn(
+                txn, event, "content.name", event.content["name"]
+            )
+
+    def _store_room_message_txn(self, txn, event):
+        if hasattr(event, "content") and "body" in event.content:
+            self.store_event_search_txn(
+                txn, event, "content.body", event.content["body"]
+            )
+
+    def _store_retention_policy_for_room_txn(self, txn, event):
+        if hasattr(event, "content") and (
+            "min_lifetime" in event.content or "max_lifetime" in event.content
+        ):
+            if (
+                "min_lifetime" in event.content
+                and not isinstance(event.content.get("min_lifetime"), integer_types)
+            ) or (
+                "max_lifetime" in event.content
+                and not isinstance(event.content.get("max_lifetime"), integer_types)
+            ):
+                # Ignore the event if one of the values isn't an integer.
+                return
+
+            self.db.simple_insert_txn(
+                txn=txn,
+                table="room_retention",
+                values={
+                    "room_id": event.room_id,
+                    "event_id": event.event_id,
+                    "min_lifetime": event.content.get("min_lifetime"),
+                    "max_lifetime": event.content.get("max_lifetime"),
+                },
+            )
+
+            self.store._invalidate_cache_and_stream(
+                txn, self.store.get_retention_policy_for_room, (event.room_id,)
+            )
+
+    def store_event_search_txn(self, txn, event, key, value):
+        """Add event to the search table
+
+        Args:
+            txn (cursor):
+            event (EventBase):
+            key (str):
+            value (str):
+        """
+        self.store.store_search_entries_txn(
+            txn,
+            (
+                SearchEntry(
+                    key=key,
+                    value=value,
+                    event_id=event.event_id,
+                    room_id=event.room_id,
+                    stream_ordering=event.internal_metadata.stream_ordering,
+                    origin_server_ts=event.origin_server_ts,
+                ),
+            ),
+        )
+
+    def _set_push_actions_for_event_and_users_txn(
+        self, txn, events_and_contexts, all_events_and_contexts
+    ):
+        """Handles moving push actions from staging table to main
+        event_push_actions table for all events in `events_and_contexts`.
+
+        Also ensures that all events in `all_events_and_contexts` are removed
+        from the push action staging area.
+
+        Args:
+            events_and_contexts (list[(EventBase, EventContext)]): events
+                we are persisting
+            all_events_and_contexts (list[(EventBase, EventContext)]): all
+                events that we were going to persist. This includes events
+                we've already persisted, etc., that wouldn't appear in
+                events_and_contexts.
+        """
+
+        sql = """
+            INSERT INTO event_push_actions (
+                room_id, event_id, user_id, actions, stream_ordering,
+                topological_ordering, notif, highlight
+            )
+            SELECT ?, event_id, user_id, actions, ?, ?, notif, highlight
+            FROM event_push_actions_staging
+            WHERE event_id = ?
+        """
+
+        if events_and_contexts:
+            txn.executemany(
+                sql,
+                (
+                    (
+                        event.room_id,
+                        event.internal_metadata.stream_ordering,
+                        event.depth,
+                        event.event_id,
+                    )
+                    for event, _ in events_and_contexts
+                ),
+            )
+
+        for event, _ in events_and_contexts:
+            user_ids = self.db.simple_select_onecol_txn(
+                txn,
+                table="event_push_actions_staging",
+                keyvalues={"event_id": event.event_id},
+                retcol="user_id",
+            )
+
+            for uid in user_ids:
+                txn.call_after(
+                    self.store.get_unread_event_push_actions_by_room_for_user.invalidate_many,
+                    (event.room_id, uid),
+                )
+
+        # Now we delete the staging area for *all* events that were being
+        # persisted.
+        txn.executemany(
+            "DELETE FROM event_push_actions_staging WHERE event_id = ?",
+            ((event.event_id,) for event, _ in all_events_and_contexts),
+        )
+
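A standalone sketch (plain sqlite3, toy data) of the staging-table pattern used above: push actions accumulate per event in a staging table while the event is in flight, get copied into the main table once the stream and topological orderings are known, and the staging rows are then deleted for every event that was considered.

    import sqlite3

    conn = sqlite3.connect(":memory:")
    conn.executescript(
        """
        CREATE TABLE staging (event_id TEXT, user_id TEXT, actions TEXT);
        CREATE TABLE main (room_id TEXT, event_id TEXT, user_id TEXT,
                           actions TEXT, stream_ordering INT);
        INSERT INTO staging VALUES ('$e1', '@alice:hs', '["notify"]');
        """
    )
    conn.execute(
        """
        INSERT INTO main (room_id, event_id, user_id, actions, stream_ordering)
        SELECT ?, event_id, user_id, actions, ? FROM staging WHERE event_id = ?
        """,
        ("!room:hs", 42, "$e1"),
    )
    conn.execute("DELETE FROM staging WHERE event_id = ?", ("$e1",))
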
+    def _remove_push_actions_for_event_id_txn(self, txn, room_id, event_id):
+        # Sad that we have to blow away the cache for the whole room here
+        txn.call_after(
+            self.store.get_unread_event_push_actions_by_room_for_user.invalidate_many,
+            (room_id,),
+        )
+        txn.execute(
+            "DELETE FROM event_push_actions WHERE room_id = ? AND event_id = ?",
+            (room_id, event_id),
+        )
+
+    def _store_rejections_txn(self, txn, event_id, reason):
+        self.db.simple_insert_txn(
+            txn,
+            table="rejections",
+            values={
+                "event_id": event_id,
+                "reason": reason,
+                "last_check": self._clock.time_msec(),
+            },
+        )
+
+    def _store_event_state_mappings_txn(
+        self, txn, events_and_contexts: Iterable[Tuple[EventBase, EventContext]]
+    ):
+        state_groups = {}
+        for event, context in events_and_contexts:
+            if event.internal_metadata.is_outlier():
+                continue
+
+            # if the event was rejected, just give it the same state as its
+            # predecessor.
+            if context.rejected:
+                state_groups[event.event_id] = context.state_group_before_event
+                continue
+
+            state_groups[event.event_id] = context.state_group
+
+        self.db.simple_insert_many_txn(
+            txn,
+            table="event_to_state_groups",
+            values=[
+                {"state_group": state_group_id, "event_id": event_id}
+                for event_id, state_group_id in iteritems(state_groups)
+            ],
+        )
+
+        for event_id, state_group_id in iteritems(state_groups):
+            txn.call_after(
+                self.store._get_state_group_for_event.prefill,
+                (event_id,),
+                state_group_id,
+            )
+
+    def _update_min_depth_for_room_txn(self, txn, room_id, depth):
+        min_depth = self.store._get_min_depth_interaction(txn, room_id)
+
+        if min_depth is not None and depth >= min_depth:
+            return
+
+        self.db.simple_upsert_txn(
+            txn,
+            table="room_depth",
+            keyvalues={"room_id": room_id},
+            values={"min_depth": depth},
+        )
+
+    def _handle_mult_prev_events(self, txn, events):
+        """
+        For the given events, update the event_edges table and the backward
+        extremities table (forward extremities are handled separately).
+        """
+        self.db.simple_insert_many_txn(
+            txn,
+            table="event_edges",
+            values=[
+                {
+                    "event_id": ev.event_id,
+                    "prev_event_id": e_id,
+                    "room_id": ev.room_id,
+                    "is_state": False,
+                }
+                for ev in events
+                for e_id in ev.prev_event_ids()
+            ],
+        )
+
+        self._update_backward_extremeties(txn, events)
+
+    def _update_backward_extremeties(self, txn, events):
+        """Updates the event_backward_extremities tables based on the new/updated
+        events being persisted.
+
+        This is called for new events *and* for events that were outliers, but
+        are now being persisted as non-outliers.
+
+        Forward extremities are handled when we first start persisting the events.
+        """
+        events_by_room = {}
+        for ev in events:
+            events_by_room.setdefault(ev.room_id, []).append(ev)
+
+        query = (
+            "INSERT INTO event_backward_extremities (event_id, room_id)"
+            " SELECT ?, ? WHERE NOT EXISTS ("
+            " SELECT 1 FROM event_backward_extremities"
+            " WHERE event_id = ? AND room_id = ?"
+            " )"
+            " AND NOT EXISTS ("
+            " SELECT 1 FROM events WHERE event_id = ? AND room_id = ? "
+            " AND outlier = ?"
+            " )"
+        )
+
+        txn.executemany(
+            query,
+            [
+                (e_id, ev.room_id, e_id, ev.room_id, e_id, ev.room_id, False)
+                for ev in events
+                for e_id in ev.prev_event_ids()
+                if not ev.internal_metadata.is_outlier()
+            ],
+        )
+
+        query = (
+            "DELETE FROM event_backward_extremities"
+            " WHERE event_id = ? AND room_id = ?"
+        )
+        txn.executemany(
+            query,
+            [
+                (ev.event_id, ev.room_id)
+                for ev in events
+                if not ev.internal_metadata.is_outlier()
+            ],
+        )
+
+    async def locally_reject_invite(self, user_id: str, room_id: str) -> int:
+        """Mark the invite has having been rejected even though we failed to
+        create a leave event for it.
+        """
+
+        sql = (
+            "UPDATE local_invites SET stream_id = ?, locally_rejected = ? WHERE"
+            " room_id = ? AND invitee = ? AND locally_rejected is NULL"
+            " AND replaced_by is NULL"
+        )
+
+        def f(txn, stream_ordering):
+            txn.execute(sql, (stream_ordering, True, room_id, user_id))
+
+            # We also clear this entry from `local_current_membership`.
+            # Ideally we'd point to a leave event, but we don't have one, so
+            # nevermind.
+            self.db.simple_delete_txn(
+                txn,
+                table="local_current_membership",
+                keyvalues={"room_id": room_id, "user_id": user_id},
+            )
+
+        with self._stream_id_gen.get_next() as stream_ordering:
+            await self.db.runInteraction("locally_reject_invite", f, stream_ordering)
+
+        return stream_ordering
diff --git a/synapse/storage/events_bg_updates.py b/synapse/storage/data_stores/main/events_bg_updates.py
index 6587f31e2b..f54c8b1ee0 100644
--- a/synapse/storage/events_bg_updates.py
+++ b/synapse/storage/data_stores/main/events_bg_updates.py
@@ -21,29 +21,31 @@ from canonicaljson import json
 
 from twisted.internet import defer
 
-from synapse.storage.background_updates import BackgroundUpdateStore
+from synapse.api.constants import EventContentFields
+from synapse.storage._base import SQLBaseStore, make_in_list_sql_clause
+from synapse.storage.database import Database
 
 logger = logging.getLogger(__name__)
 
 
-class EventsBackgroundUpdatesStore(BackgroundUpdateStore):
+class EventsBackgroundUpdatesStore(SQLBaseStore):
 
     EVENT_ORIGIN_SERVER_TS_NAME = "event_origin_server_ts"
     EVENT_FIELDS_SENDER_URL_UPDATE_NAME = "event_fields_sender_url"
     DELETE_SOFT_FAILED_EXTREMITIES = "delete_soft_failed_extremities"
 
-    def __init__(self, db_conn, hs):
-        super(EventsBackgroundUpdatesStore, self).__init__(db_conn, hs)
+    def __init__(self, database: Database, db_conn, hs):
+        super(EventsBackgroundUpdatesStore, self).__init__(database, db_conn, hs)
 
-        self.register_background_update_handler(
+        self.db.updates.register_background_update_handler(
             self.EVENT_ORIGIN_SERVER_TS_NAME, self._background_reindex_origin_server_ts
         )
-        self.register_background_update_handler(
+        self.db.updates.register_background_update_handler(
             self.EVENT_FIELDS_SENDER_URL_UPDATE_NAME,
             self._background_reindex_fields_sender,
         )
 
-        self.register_background_index_update(
+        self.db.updates.register_background_index_update(
             "event_contains_url_index",
             index_name="event_contains_url_index",
             table="events",
@@ -54,7 +56,7 @@ class EventsBackgroundUpdatesStore(BackgroundUpdateStore):
         # an event_id index on event_search is useful for the purge_history
         # api. Plus it means we get to enforce some integrity with a UNIQUE
         # clause
-        self.register_background_index_update(
+        self.db.updates.register_background_index_update(
             "event_search_event_id_idx",
             index_name="event_search_event_id_idx",
             table="event_search",
@@ -63,10 +65,39 @@ class EventsBackgroundUpdatesStore(BackgroundUpdateStore):
             psql_only=True,
         )
 
-        self.register_background_update_handler(
+        self.db.updates.register_background_update_handler(
             self.DELETE_SOFT_FAILED_EXTREMITIES, self._cleanup_extremities_bg_update
         )
 
+        self.db.updates.register_background_update_handler(
+            "redactions_received_ts", self._redactions_received_ts
+        )
+
+        # This index gets deleted in the `event_fix_redactions_bytes` update.
+        self.db.updates.register_background_index_update(
+            "event_fix_redactions_bytes_create_index",
+            index_name="redactions_censored_redacts",
+            table="redactions",
+            columns=["redacts"],
+            where_clause="have_censored",
+        )
+
+        self.db.updates.register_background_update_handler(
+            "event_fix_redactions_bytes", self._event_fix_redactions_bytes
+        )
+
+        self.db.updates.register_background_update_handler(
+            "event_store_labels", self._event_store_labels
+        )
+
+        self.db.updates.register_background_index_update(
+            "redactions_have_censored_ts_idx",
+            index_name="redactions_have_censored_ts",
+            table="redactions",
+            columns=["received_ts"],
+            where_clause="NOT have_censored",
+        )
+
     @defer.inlineCallbacks
     def _background_reindex_fields_sender(self, progress, batch_size):
         target_min_stream_id = progress["target_min_stream_id_inclusive"]
@@ -122,18 +153,20 @@ class EventsBackgroundUpdatesStore(BackgroundUpdateStore):
                 "rows_inserted": rows_inserted + len(rows),
             }
 
-            self._background_update_progress_txn(
+            self.db.updates._background_update_progress_txn(
                 txn, self.EVENT_FIELDS_SENDER_URL_UPDATE_NAME, progress
             )
 
             return len(rows)
 
-        result = yield self.runInteraction(
+        result = yield self.db.runInteraction(
             self.EVENT_FIELDS_SENDER_URL_UPDATE_NAME, reindex_txn
         )
 
         if not result:
-            yield self._end_background_update(self.EVENT_FIELDS_SENDER_URL_UPDATE_NAME)
+            yield self.db.updates._end_background_update(
+                self.EVENT_FIELDS_SENDER_URL_UPDATE_NAME
+            )
 
         return result
 
@@ -166,7 +199,7 @@ class EventsBackgroundUpdatesStore(BackgroundUpdateStore):
 
             chunks = [event_ids[i : i + 100] for i in range(0, len(event_ids), 100)]
             for chunk in chunks:
-                ev_rows = self._simple_select_many_txn(
+                ev_rows = self.db.simple_select_many_txn(
                     txn,
                     table="event_json",
                     column="event_id",
@@ -199,18 +232,20 @@ class EventsBackgroundUpdatesStore(BackgroundUpdateStore):
                 "rows_inserted": rows_inserted + len(rows_to_update),
             }
 
-            self._background_update_progress_txn(
+            self.db.updates._background_update_progress_txn(
                 txn, self.EVENT_ORIGIN_SERVER_TS_NAME, progress
             )
 
             return len(rows_to_update)
 
-        result = yield self.runInteraction(
+        result = yield self.db.runInteraction(
             self.EVENT_ORIGIN_SERVER_TS_NAME, reindex_search_txn
         )
 
         if not result:
-            yield self._end_background_update(self.EVENT_ORIGIN_SERVER_TS_NAME)
+            yield self.db.updates._end_background_update(
+                self.EVENT_ORIGIN_SERVER_TS_NAME
+            )
 
         return result
 
@@ -308,12 +343,13 @@ class EventsBackgroundUpdatesStore(BackgroundUpdateStore):
                     INNER JOIN event_json USING (event_id)
                     LEFT JOIN rejections USING (event_id)
                     WHERE
-                        prev_event_id IN (%s)
-                        AND NOT events.outlier
-                """ % (
-                    ",".join("?" for _ in to_check),
+                        NOT events.outlier
+                        AND
+                """
+                clause, args = make_in_list_sql_clause(
+                    self.database_engine, "prev_event_id", to_check
                 )
-                txn.execute(sql, to_check)
+                txn.execute(sql + clause, list(args))
 
                 for prev_event_id, event_id, metadata, rejected in txn:
                     if event_id in graph:
@@ -342,7 +378,7 @@ class EventsBackgroundUpdatesStore(BackgroundUpdateStore):
 
             to_delete.intersection_update(original_set)
 
-            deleted = self._simple_delete_many_txn(
+            deleted = self.db.simple_delete_many_txn(
                 txn=txn,
                 table="event_forward_extremities",
                 column="event_id",
@@ -358,7 +394,7 @@ class EventsBackgroundUpdatesStore(BackgroundUpdateStore):
 
             if deleted:
                 # We now need to invalidate the caches of these rooms
-                rows = self._simple_select_many_txn(
+                rows = self.db.simple_select_many_txn(
                     txn,
                     table="events",
                     column="event_id",
@@ -366,13 +402,13 @@ class EventsBackgroundUpdatesStore(BackgroundUpdateStore):
                     keyvalues={},
                     retcols=("room_id",),
                 )
-                room_ids = set(row["room_id"] for row in rows)
+                room_ids = {row["room_id"] for row in rows}
                 for room_id in room_ids:
                     txn.call_after(
                         self.get_latest_event_ids_in_room.invalidate, (room_id,)
                     )
 
-            self._simple_delete_many_txn(
+            self.db.simple_delete_many_txn(
                 txn=txn,
                 table="_extremities_to_check",
                 column="event_id",
@@ -382,18 +418,172 @@ class EventsBackgroundUpdatesStore(BackgroundUpdateStore):
 
             return len(original_set)
 
-        num_handled = yield self.runInteraction(
+        num_handled = yield self.db.runInteraction(
             "_cleanup_extremities_bg_update", _cleanup_extremities_bg_update_txn
         )
 
         if not num_handled:
-            yield self._end_background_update(self.DELETE_SOFT_FAILED_EXTREMITIES)
+            yield self.db.updates._end_background_update(
+                self.DELETE_SOFT_FAILED_EXTREMITIES
+            )
 
             def _drop_table_txn(txn):
                 txn.execute("DROP TABLE _extremities_to_check")
 
-            yield self.runInteraction(
+            yield self.db.runInteraction(
                 "_cleanup_extremities_bg_update_drop_table", _drop_table_txn
             )
 
         return num_handled
+
+    @defer.inlineCallbacks
+    def _redactions_received_ts(self, progress, batch_size):
+        """Handles filling out the `received_ts` column in redactions.
+        """
+        last_event_id = progress.get("last_event_id", "")
+
+        def _redactions_received_ts_txn(txn):
+            # Fetch the set of event IDs that we want to update
+            sql = """
+                SELECT event_id FROM redactions
+                WHERE event_id > ?
+                ORDER BY event_id ASC
+                LIMIT ?
+            """
+
+            txn.execute(sql, (last_event_id, batch_size))
+
+            rows = txn.fetchall()
+            if not rows:
+                return 0
+
+            (upper_event_id,) = rows[-1]
+
+            # Update the redactions with the received_ts.
+            #
+            # Note: Not all events have an associated received_ts, so we
+            # fall back to using origin_server_ts. If for some reason we don't
+            # have an origin_server_ts, let's just use the current timestamp.
+            #
+            # We don't want to leave it null, as then we'll never try to
+            # censor those redactions.
+            sql = """
+                UPDATE redactions
+                SET received_ts = (
+                    SELECT COALESCE(received_ts, origin_server_ts, ?) FROM events
+                    WHERE events.event_id = redactions.event_id
+                )
+                WHERE ? <= event_id AND event_id <= ?
+            """
+
+            txn.execute(sql, (self._clock.time_msec(), last_event_id, upper_event_id))
+
+            self.db.updates._background_update_progress_txn(
+                txn, "redactions_received_ts", {"last_event_id": upper_event_id}
+            )
+
+            return len(rows)
+
+        count = yield self.db.runInteraction(
+            "_redactions_received_ts", _redactions_received_ts_txn
+        )
+
+        if not count:
+            yield self.db.updates._end_background_update("redactions_received_ts")
+
+        return count
+
+    @defer.inlineCallbacks
+    def _event_fix_redactions_bytes(self, progress, batch_size):
+        """Undoes hex encoded censored redacted event JSON.
+        """
+
+        def _event_fix_redactions_bytes_txn(txn):
+            # This update is quite fast due to the new index.
+            txn.execute(
+                """
+                UPDATE event_json
+                SET
+                    json = convert_from(json::bytea, 'utf8')
+                FROM redactions
+                WHERE
+                    redactions.have_censored
+                    AND event_json.event_id = redactions.redacts
+                    AND json NOT LIKE '{%';
+                """
+            )
+
+            txn.execute("DROP INDEX redactions_censored_redacts")
+
+        yield self.db.runInteraction(
+            "_event_fix_redactions_bytes", _event_fix_redactions_bytes_txn
+        )
+
+        yield self.db.updates._end_background_update("event_fix_redactions_bytes")
+
+        return 1
+
+    @defer.inlineCallbacks
+    def _event_store_labels(self, progress, batch_size):
+        """Background update handler which will store labels for existing events."""
+        last_event_id = progress.get("last_event_id", "")
+
+        def _event_store_labels_txn(txn):
+            txn.execute(
+                """
+                SELECT event_id, json FROM event_json
+                LEFT JOIN event_labels USING (event_id)
+                WHERE event_id > ? AND label IS NULL
+                ORDER BY event_id LIMIT ?
+                """,
+                (last_event_id, batch_size),
+            )
+
+            results = list(txn)
+
+            nbrows = 0
+            last_row_event_id = ""
+            for (event_id, event_json_raw) in results:
+                try:
+                    event_json = json.loads(event_json_raw)
+
+                    self.db.simple_insert_many_txn(
+                        txn=txn,
+                        table="event_labels",
+                        values=[
+                            {
+                                "event_id": event_id,
+                                "label": label,
+                                "room_id": event_json["room_id"],
+                                "topological_ordering": event_json["depth"],
+                            }
+                            for label in event_json["content"].get(
+                                EventContentFields.LABELS, []
+                            )
+                            if isinstance(label, str)
+                        ],
+                    )
+                except Exception as e:
+                    logger.warning(
+                        "Unable to load event %s (no labels will be imported): %s",
+                        event_id,
+                        e,
+                    )
+
+                nbrows += 1
+                last_row_event_id = event_id
+
+            self.db.updates._background_update_progress_txn(
+                txn, "event_store_labels", {"last_event_id": last_row_event_id}
+            )
+
+            return nbrows
+
+        num_rows = yield self.db.runInteraction(
+            desc="event_store_labels", func=_event_store_labels_txn
+        )
+
+        if not num_rows:
+            yield self.db.updates._end_background_update("event_store_labels")
+
+        return num_rows
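For illustration, the sort of stored event JSON this background update picks up (a made-up example): events whose content carries the labels field, i.e. EventContentFields.LABELS, "org.matrix.labels" per MSC2326. Only string labels are imported, one event_labels row per label.

    example_event_json = {
        "room_id": "!room:example.org",
        "depth": 42,
        "content": {
            "body": "quarterly report",
            "org.matrix.labels": ["#work", "#finance"],
        },
    }
    # This would yield two event_labels rows, both with
    # topological_ordering=42 and room_id="!room:example.org".
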
diff --git a/synapse/storage/events_worker.py b/synapse/storage/data_stores/main/events_worker.py
index c6fa7f82fd..213d69100a 100644
--- a/synapse/storage/events_worker.py
+++ b/synapse/storage/data_stores/main/events_worker.py
@@ -17,26 +17,35 @@ from __future__ import division
 
 import itertools
 import logging
+import threading
 from collections import namedtuple
+from typing import List, Optional, Tuple
 
 from canonicaljson import json
+from constantly import NamedConstant, Names
 
 from twisted.internet import defer
 
 from synapse.api.constants import EventTypes
-from synapse.api.errors import NotFoundError
-from synapse.api.room_versions import EventFormatVersions
-from synapse.events import FrozenEvent, event_type_from_format_version  # noqa: F401
-from synapse.events.snapshot import EventContext  # noqa: F401
+from synapse.api.errors import NotFoundError, SynapseError
+from synapse.api.room_versions import (
+    KNOWN_ROOM_VERSIONS,
+    EventFormatVersions,
+    RoomVersions,
+)
+from synapse.events import make_event_from_dict
 from synapse.events.utils import prune_event
-from synapse.logging.context import LoggingContext, PreserveLoggingContext
+from synapse.logging.context import PreserveLoggingContext, current_context
 from synapse.metrics.background_process_metrics import run_as_background_process
+from synapse.replication.slave.storage._slaved_id_tracker import SlavedIdTracker
+from synapse.storage._base import SQLBaseStore, make_in_list_sql_clause
+from synapse.storage.database import Database
+from synapse.storage.util.id_generators import StreamIdGenerator
 from synapse.types import get_domain_from_id
-from synapse.util import batch_iter
+from synapse.util.caches.descriptors import Cache, cached, cachedInlineCallbacks
+from synapse.util.iterutils import batch_iter
 from synapse.util.metrics import Measure
 
-from ._base import SQLBaseStore
-
 logger = logging.getLogger(__name__)
 
 
@@ -53,7 +62,64 @@ EVENT_QUEUE_TIMEOUT_S = 0.1  # Timeout when waiting for requests for events
 _EventCacheEntry = namedtuple("_EventCacheEntry", ("event", "redacted_event"))
 
 
+class EventRedactBehaviour(Names):
+    """
+    What to do when retrieving a redacted event from the database.
+    """
+
+    AS_IS = NamedConstant()
+    REDACT = NamedConstant()
+    BLOCK = NamedConstant()
+
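A brief usage sketch of the enum above (a hypothetical call site, not part of the patch): `store` is assumed to be an EventsWorkerStore, and the coroutine is assumed to be driven by Twisted, e.g. via defer.ensureDeferred. It fetches an event without applying redactions, as a redaction-aware admin or debugging tool might.

    async def fetch_unredacted(store, event_id: str):
        # Deferreds returned by get_event are awaitable under Twisted.
        return await store.get_event(
            event_id,
            redact_behaviour=EventRedactBehaviour.AS_IS,
            allow_none=True,
        )
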
+
 class EventsWorkerStore(SQLBaseStore):
+    def __init__(self, database: Database, db_conn, hs):
+        super(EventsWorkerStore, self).__init__(database, db_conn, hs)
+
+        if hs.config.worker.writers.events == hs.get_instance_name():
+            # We are the process in charge of generating stream ids for events,
+            # so instantiate ID generators based on the database
+            self._stream_id_gen = StreamIdGenerator(
+                db_conn,
+                "events",
+                "stream_ordering",
+                extra_tables=[("local_invites", "stream_id")],
+            )
+            self._backfill_id_gen = StreamIdGenerator(
+                db_conn,
+                "events",
+                "stream_ordering",
+                step=-1,
+                extra_tables=[("ex_outlier_stream", "event_stream_ordering")],
+            )
+        else:
+            # Another process is in charge of persisting events and generating
+            # stream IDs: rely on the replication streams to let us know which
+            # IDs we can process.
+            self._stream_id_gen = SlavedIdTracker(db_conn, "events", "stream_ordering")
+            self._backfill_id_gen = SlavedIdTracker(
+                db_conn, "events", "stream_ordering", step=-1
+            )
+
+        self._get_event_cache = Cache(
+            "*getEvent*",
+            keylen=3,
+            max_entries=hs.config.caches.event_cache_size,
+            apply_cache_factor_from_config=False,
+        )
+
+        self._event_fetch_lock = threading.Condition()
+        self._event_fetch_list = []
+        self._event_fetch_ongoing = 0
+
+    def process_replication_rows(self, stream_name, instance_name, token, rows):
+        if stream_name == "events":
+            self._stream_id_gen.advance(token)
+        elif stream_name == "backfill":
+            self._backfill_id_gen.advance(-token)
+
+        super().process_replication_rows(stream_name, instance_name, token, rows)
+
     def get_received_ts(self, event_id):
         """Get received_ts (when it was persisted) for the event.
 
@@ -66,7 +132,7 @@ class EventsWorkerStore(SQLBaseStore):
             Deferred[int|None]: Timestamp in milliseconds, or None for events
             that were persisted before received_ts was implemented.
         """
-        return self._simple_select_one_onecol(
+        return self.db.simple_select_one_onecol(
             table="events",
             keyvalues={"event_id": event_id},
             retcol="received_ts",
@@ -105,32 +171,41 @@ class EventsWorkerStore(SQLBaseStore):
 
             return ts
 
-        return self.runInteraction(
+        return self.db.runInteraction(
             "get_approximate_received_ts", _get_approximate_received_ts_txn
         )
 
     @defer.inlineCallbacks
     def get_event(
         self,
-        event_id,
-        check_redacted=True,
-        get_prev_content=False,
-        allow_rejected=False,
-        allow_none=False,
-        check_room_id=None,
+        event_id: str,
+        redact_behaviour: EventRedactBehaviour = EventRedactBehaviour.REDACT,
+        get_prev_content: bool = False,
+        allow_rejected: bool = False,
+        allow_none: bool = False,
+        check_room_id: Optional[str] = None,
     ):
         """Get an event from the database by event_id.
 
         Args:
-            event_id (str): The event_id of the event to fetch
-            check_redacted (bool): If True, check if event has been redacted
-                and redact it.
-            get_prev_content (bool): If True and event is a state event,
+            event_id: The event_id of the event to fetch
+
+            redact_behaviour: Determine what to do with a redacted event. Possible values:
+                * AS_IS - Return the full event body with no redacted content
+                * REDACT - Return the event but with a redacted body
+                * BLOCK - Do not return redacted events (behave as per allow_none
+                    if the event is redacted)
+
+            get_prev_content: If True and event is a state event,
                 include the previous states content in the unsigned field.
-            allow_rejected (bool): If True return rejected events.
-            allow_none (bool): If True, return None if no event found, if
+
+            allow_rejected: If True, return rejected events. Otherwise,
+                behave as per allow_none.
+
+            allow_none: If True, return None if no event found, if
                 False throw a NotFoundError
-            check_room_id (str|None): if not None, check the room of the found event.
+
+            check_room_id: if not None, check the room of the found event.
                 If there is a mismatch, behave as per allow_none.
 
         Returns:
@@ -141,7 +216,7 @@ class EventsWorkerStore(SQLBaseStore):
 
         events = yield self.get_events_as_list(
             [event_id],
-            check_redacted=check_redacted,
+            redact_behaviour=redact_behaviour,
             get_prev_content=get_prev_content,
             allow_rejected=allow_rejected,
         )
@@ -160,27 +235,34 @@ class EventsWorkerStore(SQLBaseStore):
     @defer.inlineCallbacks
     def get_events(
         self,
-        event_ids,
-        check_redacted=True,
-        get_prev_content=False,
-        allow_rejected=False,
+        event_ids: List[str],
+        redact_behaviour: EventRedactBehaviour = EventRedactBehaviour.REDACT,
+        get_prev_content: bool = False,
+        allow_rejected: bool = False,
     ):
         """Get events from the database
 
         Args:
-            event_ids (list): The event_ids of the events to fetch
-            check_redacted (bool): If True, check if event has been redacted
-                and redact it.
-            get_prev_content (bool): If True and event is a state event,
+            event_ids: The event_ids of the events to fetch
+
+            redact_behaviour: Determine what to do with a redacted event. Possible
+                values:
+                * AS_IS - Return the full event body with no redacted content
+                * REDACT - Return the event but with a redacted body
+                * BLOCK - Do not return redacted events (omit them from the response)
+
+            get_prev_content: If True and event is a state event,
                 include the previous states content in the unsigned field.
-            allow_rejected (bool): If True return rejected events.
+
+            allow_rejected: If True, return rejected events. Otherwise,
+                omits rejected events from the response.
 
         Returns:
             Deferred : Dict from event_id to event.
         """
         events = yield self.get_events_as_list(
             event_ids,
-            check_redacted=check_redacted,
+            redact_behaviour=redact_behaviour,
             get_prev_content=get_prev_content,
             allow_rejected=allow_rejected,
         )
@@ -190,21 +272,29 @@ class EventsWorkerStore(SQLBaseStore):
     @defer.inlineCallbacks
     def get_events_as_list(
         self,
-        event_ids,
-        check_redacted=True,
-        get_prev_content=False,
-        allow_rejected=False,
+        event_ids: List[str],
+        redact_behaviour: EventRedactBehaviour = EventRedactBehaviour.REDACT,
+        get_prev_content: bool = False,
+        allow_rejected: bool = False,
     ):
         """Get events from the database and return in a list in the same order
         as given by `event_ids` arg.
 
+        Unknown events will be omitted from the response.
+
         Args:
-            event_ids (list): The event_ids of the events to fetch
-            check_redacted (bool): If True, check if event has been redacted
-                and redact it.
-            get_prev_content (bool): If True and event is a state event,
+            event_ids: The event_ids of the events to fetch
+
+            redact_behaviour: Determine what to do with a redacted event. Possible values:
+                * AS_IS - Return the full event body with no redacted content
+                * REDACT - Return the event but with a redacted body
+                * BLOCK - Do not return redacted events (omit them from the response)
+
+            get_prev_content: If True and event is a state event,
                 include the previous states content in the unsigned field.
-            allow_rejected (bool): If True return rejected events.
+
+            allow_rejected: If True, return rejected events. Otherwise,
+                omits rejected events from the response.
 
         Returns:
             Deferred[list[EventBase]]: List of events fetched from the database. The
@@ -238,6 +328,20 @@ class EventsWorkerStore(SQLBaseStore):
             # we have to recheck auth now.
 
             if not allow_rejected and entry.event.type == EventTypes.Redaction:
+                if entry.event.redacts is None:
+                    # A redacted redaction doesn't have a `redacts` key, in
+                    # which case let's just withhold the event.
+                    #
+                    # Note: most of the time, if a redaction has itself been
+                    # redacted, we still have the un-redacted event in the DB
+                    # and so we'll still see the `redacts` key. However, this
+                    # isn't always true, e.g. if we have censored the event.
+                    logger.debug(
+                        "Withholding redaction event %s as we don't have redacts key",
+                        event_id,
+                    )
+                    continue
+
                 redacted_event_id = entry.event.redacts
                 event_map = yield self._get_events_from_cache_or_db([redacted_event_id])
                 original_event_entry = event_map.get(redacted_event_id)
@@ -292,10 +396,14 @@ class EventsWorkerStore(SQLBaseStore):
                     # Update the cache to save doing the checks again.
                     entry.event.internal_metadata.recheck_redaction = False
 
-            if check_redacted and entry.redacted_event:
-                event = entry.redacted_event
-            else:
-                event = entry.event
+            event = entry.event
+
+            if entry.redacted_event:
+                if redact_behaviour == EventRedactBehaviour.BLOCK:
+                    # Skip this event
+                    continue
+                elif redact_behaviour == EventRedactBehaviour.REDACT:
+                    event = entry.redacted_event
 
             events.append(event)
 
@@ -319,9 +427,14 @@ class EventsWorkerStore(SQLBaseStore):
 
         If events are pulled from the database, they will be cached for future lookups.
 
+        Unknown events are omitted from the response.
+
         Args:
+
             event_ids (Iterable[str]): The event_ids of the events to fetch
-            allow_rejected (bool): Whether to include rejected events
+
+            allow_rejected (bool): Whether to include rejected events. If False,
+                rejected events are omitted from the response.
 
         Returns:
             Deferred[Dict[str, _EventCacheEntry]]:
@@ -334,7 +447,7 @@ class EventsWorkerStore(SQLBaseStore):
         missing_events_ids = [e for e in event_ids if e not in event_entry_map]
 
         if missing_events_ids:
-            log_ctx = LoggingContext.current_context()
+            log_ctx = current_context()
             log_ctx.record_event_fetch(len(missing_events_ids))
 
             # Note that _get_events_from_db is also responsible for turning db rows
@@ -422,11 +535,11 @@ class EventsWorkerStore(SQLBaseStore):
         """
         with Measure(self._clock, "_fetch_event_list"):
             try:
-                events_to_fetch = set(
+                events_to_fetch = {
                     event_id for events, _ in event_list for event_id in events
-                )
+                }
 
-                row_dict = self._new_transaction(
+                row_dict = self.db.new_transaction(
                     conn, "do_fetch", [], [], self._fetch_event_rows, events_to_fetch
                 )
 
@@ -456,9 +569,13 @@ class EventsWorkerStore(SQLBaseStore):
 
         Returned events will be added to the cache for future lookups.
 
+        Unknown events are omitted from the response.
+
         Args:
             event_ids (Iterable[str]): The event_ids of the events to fetch
-            allow_rejected (bool): Whether to include rejected events
+
+            allow_rejected (bool): Whether to include rejected events. If False,
+                rejected events are omitted from the response.
 
         Returns:
             Deferred[Dict[str, _EventCacheEntry]]:
@@ -504,15 +621,56 @@ class EventsWorkerStore(SQLBaseStore):
                 # of a event format version, so it must be a V1 event.
                 format_version = EventFormatVersions.V1
 
-            original_ev = event_type_from_format_version(format_version)(
+            room_version_id = row["room_version_id"]
+
+            if not room_version_id:
+                # this should only happen for out-of-band membership events
+                if not internal_metadata.get("out_of_band_membership"):
+                    logger.warning(
+                        "Room %s for event %s is unknown", d["room_id"], event_id
+                    )
+                    continue
+
+                # take a wild stab at the room version based on the event format
+                if format_version == EventFormatVersions.V1:
+                    room_version = RoomVersions.V1
+                elif format_version == EventFormatVersions.V2:
+                    room_version = RoomVersions.V3
+                else:
+                    room_version = RoomVersions.V5
+            else:
+                room_version = KNOWN_ROOM_VERSIONS.get(room_version_id)
+                if not room_version:
+                    logger.error(
+                        "Event %s in room %s has unknown room version %s",
+                        event_id,
+                        d["room_id"],
+                        room_version_id,
+                    )
+                    continue
+
+                if room_version.event_format != format_version:
+                    logger.error(
+                        "Event %s in room %s with version %s has wrong format: "
+                        "expected %s, was %s",
+                        event_id,
+                        d["room_id"],
+                        room_version_id,
+                        room_version.event_format,
+                        format_version,
+                    )
+                    continue
+
+            original_ev = make_event_from_dict(
                 event_dict=d,
+                room_version=room_version,
                 internal_metadata_dict=internal_metadata,
                 rejected_reason=rejected_reason,
             )
 
             event_map[event_id] = original_ev
 
-        # finally, we can decide whether each one nededs redacting, and build
+        # finally, we can decide whether each one needs redacting, and build
         # the cache entries.
         result_map = {}
         for event_id, original_ev in event_map.items():
@@ -558,7 +716,7 @@ class EventsWorkerStore(SQLBaseStore):
 
         if should_start:
             run_as_background_process(
-                "fetch_events", self.runWithConnection, self._do_fetch
+                "fetch_events", self.db.runWithConnection, self._do_fetch
             )
 
         logger.debug("Loading %d events: %s", len(events), events)
@@ -585,6 +743,12 @@ class EventsWorkerStore(SQLBaseStore):
            of EventFormatVersions. 'None' means the event predates
            EventFormatVersions (so the event is format V1).
 
+         * room_version_id (str|None): The version of the room which contains the event.
+           Hopefully one of the versions in KNOWN_ROOM_VERSIONS.
+
+           For historical reasons, there may be a few events in the database which
+           do not have an associated room; in this case None will be returned here.
+
          * rejected_reason (str|None): if the event was rejected, the reason
            why.
 
@@ -600,19 +764,24 @@ class EventsWorkerStore(SQLBaseStore):
         """
         event_dict = {}
         for evs in batch_iter(event_ids, 200):
-            sql = (
-                "SELECT "
-                " e.event_id, "
-                " e.internal_metadata,"
-                " e.json,"
-                " e.format_version, "
-                " rej.reason "
-                " FROM event_json as e"
-                " LEFT JOIN rejections as rej USING (event_id)"
-                " WHERE e.event_id IN (%s)"
-            ) % (",".join(["?"] * len(evs)),)
-
-            txn.execute(sql, evs)
+            sql = """\
+                SELECT
+                  e.event_id,
+                  e.internal_metadata,
+                  e.json,
+                  e.format_version,
+                  r.room_version,
+                  rej.reason
+                FROM event_json as e
+                  LEFT JOIN rooms r USING (room_id)
+                  LEFT JOIN rejections as rej USING (event_id)
+                WHERE """
+
+            clause, args = make_in_list_sql_clause(
+                txn.database_engine, "e.event_id", evs
+            )
+
+            txn.execute(sql + clause, args)
 
             for row in txn:
                 event_id = row[0]
@@ -621,16 +790,17 @@ class EventsWorkerStore(SQLBaseStore):
                     "internal_metadata": row[1],
                     "json": row[2],
                     "format_version": row[3],
-                    "rejected_reason": row[4],
+                    "room_version_id": row[4],
+                    "rejected_reason": row[5],
                     "redactions": [],
                 }
 
             # check for redactions
-            redactions_sql = (
-                "SELECT event_id, redacts FROM redactions WHERE redacts IN (%s)"
-            ) % (",".join(["?"] * len(evs)),)
+            redactions_sql = "SELECT event_id, redacts FROM redactions WHERE "
 
-            txn.execute(redactions_sql, evs)
+            clause, args = make_in_list_sql_clause(txn.database_engine, "redacts", evs)
+
+            txn.execute(redactions_sql + clause, args)
 
             for (redacter, redacted) in txn:
                 d = event_dict.get(redacted)
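
Both queries in this hunk now build their IN-list via make_in_list_sql_clause rather than interpolating one "?" per id by hand. As I understand the helper, it returns a (clause, args) pair: a placeholder-per-item IN list on SQLite, or a single "= ANY(?)" array parameter on engines that support it. A rough re-implementation of the simple case, for illustration only:

    from typing import Iterable, List, Tuple

    def naive_in_list_clause(column: str, values: Iterable[str]) -> Tuple[str, List[str]]:
        """Build '<column> IN (?, ?, ...)' plus its parameter list."""
        values = list(values)
        placeholders = ", ".join("?" for _ in values)
        return "%s IN (%s)" % (column, placeholders), values

    clause, args = naive_in_list_clause("e.event_id", ["$a:example.com", "$b:example.com"])
    assert clause == "e.event_id IN (?, ?)"
    assert args == ["$a:example.com", "$b:example.com"]
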
@@ -715,7 +885,7 @@ class EventsWorkerStore(SQLBaseStore):
         """Given a list of event ids, check if we have already processed and
         stored them as non outliers.
         """
-        rows = yield self._simple_select_many_batch(
+        rows = yield self.db.simple_select_many_batch(
             table="events",
             retcols=("event_id",),
             column="event_id",
@@ -724,7 +894,7 @@ class EventsWorkerStore(SQLBaseStore):
             desc="have_events_in_timeline",
         )
 
-        return set(r["event_id"] for r in rows)
+        return {r["event_id"] for r in rows}
 
     @defer.inlineCallbacks
     def have_seen_events(self, event_ids):
@@ -739,52 +909,21 @@ class EventsWorkerStore(SQLBaseStore):
         results = set()
 
         def have_seen_events_txn(txn, chunk):
-            sql = "SELECT event_id FROM events as e WHERE e.event_id IN (%s)" % (
-                ",".join("?" * len(chunk)),
+            sql = "SELECT event_id FROM events as e WHERE "
+            clause, args = make_in_list_sql_clause(
+                txn.database_engine, "e.event_id", chunk
             )
-            txn.execute(sql, chunk)
+            txn.execute(sql + clause, args)
             for (event_id,) in txn:
                 results.add(event_id)
 
         # break the input up into chunks of 100
         input_iterator = iter(event_ids)
         for chunk in iter(lambda: list(itertools.islice(input_iterator, 100)), []):
-            yield self.runInteraction("have_seen_events", have_seen_events_txn, chunk)
-        return results
-
-    def get_seen_events_with_rejections(self, event_ids):
-        """Given a list of event ids, check if we rejected them.
-
-        Args:
-            event_ids (list[str])
-
-        Returns:
-            Deferred[dict[str, str|None):
-                Has an entry for each event id we already have seen. Maps to
-                the rejected reason string if we rejected the event, else maps
-                to None.
-        """
-        if not event_ids:
-            return defer.succeed({})
-
-        def f(txn):
-            sql = (
-                "SELECT e.event_id, reason FROM events as e "
-                "LEFT JOIN rejections as r ON e.event_id = r.event_id "
-                "WHERE e.event_id = ?"
+            yield self.db.runInteraction(
+                "have_seen_events", have_seen_events_txn, chunk
             )
-
-            res = {}
-            for event_id in event_ids:
-                txn.execute(sql, (event_id,))
-                row = txn.fetchone()
-                if row:
-                    _, rejected = row
-                    res[event_id] = rejected
-
-            return res
-
-        return self.runInteraction("get_seen_events_with_rejections", f)
+        return results
 
     def _get_total_state_event_counts_txn(self, txn, room_id):
         """
@@ -810,7 +949,7 @@ class EventsWorkerStore(SQLBaseStore):
         Returns:
             Deferred[int]
         """
-        return self.runInteraction(
+        return self.db.runInteraction(
             "get_total_state_event_counts",
             self._get_total_state_event_counts_txn,
             room_id,
@@ -835,7 +974,7 @@ class EventsWorkerStore(SQLBaseStore):
         Returns:
             Deferred[int]
         """
-        return self.runInteraction(
+        return self.db.runInteraction(
             "get_current_state_event_counts",
             self._get_current_state_event_counts_txn,
             room_id,
@@ -862,3 +1001,343 @@ class EventsWorkerStore(SQLBaseStore):
         complexity_v1 = round(state_events / 500, 2)
 
         return {"v1": complexity_v1}
+
+    def get_current_backfill_token(self):
+        """The current minimum token that backfilled events have reached"""
+        return -self._backfill_id_gen.get_current_token()
+
+    def get_current_events_token(self):
+        """The current maximum token that events have reached"""
+        return self._stream_id_gen.get_current_token()
+
+    def get_all_new_forward_event_rows(self, last_id, current_id, limit):
+        """Returns new events, for the Events replication stream
+
+        Args:
+            last_id: the last stream_id from the previous batch.
+            current_id: the maximum stream_id to return up to
+            limit: the maximum number of rows to return
+
+        Returns: Deferred[List[Tuple]]
+            a list of events stream rows. Each tuple consists of a stream id as
+            the first element, followed by fields suitable for casting into an
+            EventsStreamRow.
+        """
+
+        def get_all_new_forward_event_rows(txn):
+            sql = (
+                "SELECT e.stream_ordering, e.event_id, e.room_id, e.type,"
+                " state_key, redacts, relates_to_id"
+                " FROM events AS e"
+                " LEFT JOIN redactions USING (event_id)"
+                " LEFT JOIN state_events USING (event_id)"
+                " LEFT JOIN event_relations USING (event_id)"
+                " WHERE ? < stream_ordering AND stream_ordering <= ?"
+                " ORDER BY stream_ordering ASC"
+                " LIMIT ?"
+            )
+            txn.execute(sql, (last_id, current_id, limit))
+            return txn.fetchall()
+
+        return self.db.runInteraction(
+            "get_all_new_forward_event_rows", get_all_new_forward_event_rows
+        )
+
+    def get_ex_outlier_stream_rows(self, last_id, current_id):
+        """Returns de-outliered events, for the Events replication stream
+
+        Args:
+            last_id: the last stream_id from the previous batch.
+            current_id: the maximum stream_id to return up to
+
+        Returns: Deferred[List[Tuple]]
+            a list of events stream rows. Each tuple consists of a stream id as
+            the first element, followed by fields suitable for casting into an
+            EventsStreamRow.
+        """
+
+        def get_ex_outlier_stream_rows_txn(txn):
+            sql = (
+                "SELECT event_stream_ordering, e.event_id, e.room_id, e.type,"
+                " state_key, redacts, relates_to_id"
+                " FROM events AS e"
+                " INNER JOIN ex_outlier_stream USING (event_id)"
+                " LEFT JOIN redactions USING (event_id)"
+                " LEFT JOIN state_events USING (event_id)"
+                " LEFT JOIN event_relations USING (event_id)"
+                " WHERE ? < event_stream_ordering"
+                " AND event_stream_ordering <= ?"
+                " ORDER BY event_stream_ordering ASC"
+            )
+
+            txn.execute(sql, (last_id, current_id))
+            return txn.fetchall()
+
+        return self.db.runInteraction(
+            "get_ex_outlier_stream_rows", get_ex_outlier_stream_rows_txn
+        )
+
+    def get_all_new_backfill_event_rows(self, last_id, current_id, limit):
+        if last_id == current_id:
+            return defer.succeed([])
+
+        def get_all_new_backfill_event_rows(txn):
+            sql = (
+                "SELECT -e.stream_ordering, e.event_id, e.room_id, e.type,"
+                " state_key, redacts, relates_to_id"
+                " FROM events AS e"
+                " LEFT JOIN redactions USING (event_id)"
+                " LEFT JOIN state_events USING (event_id)"
+                " LEFT JOIN event_relations USING (event_id)"
+                " WHERE ? > stream_ordering AND stream_ordering >= ?"
+                " ORDER BY stream_ordering ASC"
+                " LIMIT ?"
+            )
+            txn.execute(sql, (-last_id, -current_id, limit))
+            new_event_updates = txn.fetchall()
+
+            if len(new_event_updates) == limit:
+                upper_bound = new_event_updates[-1][0]
+            else:
+                upper_bound = current_id
+
+            sql = (
+                "SELECT -event_stream_ordering, e.event_id, e.room_id, e.type,"
+                " state_key, redacts, relates_to_id"
+                " FROM events AS e"
+                " INNER JOIN ex_outlier_stream USING (event_id)"
+                " LEFT JOIN redactions USING (event_id)"
+                " LEFT JOIN state_events USING (event_id)"
+                " LEFT JOIN event_relations USING (event_id)"
+                " WHERE ? > event_stream_ordering"
+                " AND event_stream_ordering >= ?"
+                " ORDER BY event_stream_ordering DESC"
+            )
+            txn.execute(sql, (-last_id, -upper_bound))
+            new_event_updates.extend(txn.fetchall())
+
+            return new_event_updates
+
+        return self.db.runInteraction(
+            "get_all_new_backfill_event_rows", get_all_new_backfill_event_rows
+        )
+
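
Backfilled events are persisted with negative stream orderings, but the replication stream hands out positive, increasing tokens, which is why get_current_backfill_token negates the generator's position and the query above negates both the bound tokens and the selected column. A small worked example with illustrative values:

    # Backfilled rows are persisted with stream_ordering -1, -2, -3, ...
    backfill_orderings = [-1, -2, -3]

    # The token exposed to replication consumers is the negated position, so
    # it grows as more history is backfilled.
    current_token = -min(backfill_orderings)
    assert current_token == 3

    # Fetching "rows after token 1, up to token 3" negates the bounds again so
    # they can be compared with the stored (negative) orderings, mirroring
    # "WHERE ? > stream_ordering AND stream_ordering >= ?" above.
    last_id, current_id = 1, 3
    rows = [o for o in backfill_orderings if -last_id > o >= -current_id]
    assert rows == [-2, -3]

    # ...and the SELECT returns -stream_ordering, so consumers see 2 and 3.
    assert [-o for o in rows] == [2, 3]
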
+    async def get_all_updated_current_state_deltas(
+        self, from_token: int, to_token: int, target_row_count: int
+    ) -> Tuple[List[Tuple], int, bool]:
+        """Fetch updates from current_state_delta_stream
+
+        Args:
+            from_token: The previous stream token. Updates from this stream id will
+                be excluded.
+
+            to_token: The current stream token (ie the upper limit). Updates up to this
+                stream id will be included (modulo the 'limit' param)
+
+            target_row_count: The number of rows to try to return. If more rows are
+                available, we will set 'limited' in the result. In the event of a large
+                batch, we may return more rows than this.
+        Returns:
+            A triplet `(updates, new_last_token, limited)`, where:
+               * `updates` is a list of database tuples.
+               * `new_last_token` is the new position in the stream.
+               * `limited` is whether there are more updates to fetch.
+        """
+
+        def get_all_updated_current_state_deltas_txn(txn):
+            sql = """
+                SELECT stream_id, room_id, type, state_key, event_id
+                FROM current_state_delta_stream
+                WHERE ? < stream_id AND stream_id <= ?
+                ORDER BY stream_id ASC LIMIT ?
+            """
+            txn.execute(sql, (from_token, to_token, target_row_count))
+            return txn.fetchall()
+
+        def get_deltas_for_stream_id_txn(txn, stream_id):
+            sql = """
+                SELECT stream_id, room_id, type, state_key, event_id
+                FROM current_state_delta_stream
+                WHERE stream_id = ?
+            """
+            txn.execute(sql, [stream_id])
+            return txn.fetchall()
+
+        # we need to make sure that, for every stream id in the results, we get *all*
+        # the rows with that stream id.
+
+        rows = await self.db.runInteraction(
+            "get_all_updated_current_state_deltas",
+            get_all_updated_current_state_deltas_txn,
+        )  # type: List[Tuple]
+
+        # if we've got fewer rows than the limit, we're good
+        if len(rows) < target_row_count:
+            return rows, to_token, False
+
+        # we hit the limit, so reduce the upper limit so that we exclude the stream id
+        # of the last row in the result.
+        assert rows[-1][0] <= to_token
+        to_token = rows[-1][0] - 1
+
+        # search backwards through the list for the point to truncate
+        for idx in range(len(rows) - 1, 0, -1):
+            if rows[idx - 1][0] <= to_token:
+                return rows[:idx], to_token, True
+
+        # Bother: we didn't get a full set of changes for even a single
+        # stream id. Let's run the query again, without a row limit, but for
+        # just one stream id.
+        to_token += 1
+        rows = await self.db.runInteraction(
+            "get_deltas_for_stream_id", get_deltas_for_stream_id_txn, to_token
+        )
+
+        return rows, to_token, True
+
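
The post-processing above exists so that a single stream id is never split across two replication batches: if the row limit lands mid-stream-id, the batch is truncated back to the last complete id, or re-fetched without a limit for just that one id. A worked example of the truncation step, with made-up rows:

    # (stream_id, ...) rows returned with target_row_count = 4
    rows = [(7, "a"), (7, "b"), (8, "a"), (8, "b")]

    # We hit the limit, so we cannot be sure we saw every row for stream id 8.
    # Drop it, and report everything up to and including stream id 7 instead.
    to_token = rows[-1][0] - 1  # 7
    idx = max(i for i in range(1, len(rows)) if rows[i - 1][0] <= to_token)
    batch, limited = rows[:idx], True

    assert to_token == 7
    assert batch == [(7, "a"), (7, "b")]
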
+    @cached(num_args=5, max_entries=10)
+    def get_all_new_events(
+        self,
+        last_backfill_id,
+        last_forward_id,
+        current_backfill_id,
+        current_forward_id,
+        limit,
+    ):
+        """Get all the new events that have arrived at the server either as
+        new events or as backfilled events"""
+        have_backfill_events = last_backfill_id != current_backfill_id
+        have_forward_events = last_forward_id != current_forward_id
+
+        if not have_backfill_events and not have_forward_events:
+            return defer.succeed(AllNewEventsResult([], [], [], []))
+
+        def get_all_new_events_txn(txn):
+            sql = (
+                "SELECT e.stream_ordering, e.event_id, e.room_id, e.type,"
+                " state_key, redacts"
+                " FROM events AS e"
+                " LEFT JOIN redactions USING (event_id)"
+                " LEFT JOIN state_events USING (event_id)"
+                " WHERE ? < stream_ordering AND stream_ordering <= ?"
+                " ORDER BY stream_ordering ASC"
+                " LIMIT ?"
+            )
+            if have_forward_events:
+                txn.execute(sql, (last_forward_id, current_forward_id, limit))
+                new_forward_events = txn.fetchall()
+
+                if len(new_forward_events) == limit:
+                    upper_bound = new_forward_events[-1][0]
+                else:
+                    upper_bound = current_forward_id
+
+                sql = (
+                    "SELECT event_stream_ordering, event_id, state_group"
+                    " FROM ex_outlier_stream"
+                    " WHERE ? > event_stream_ordering"
+                    " AND event_stream_ordering >= ?"
+                    " ORDER BY event_stream_ordering DESC"
+                )
+                txn.execute(sql, (last_forward_id, upper_bound))
+                forward_ex_outliers = txn.fetchall()
+            else:
+                new_forward_events = []
+                forward_ex_outliers = []
+
+            sql = (
+                "SELECT -e.stream_ordering, e.event_id, e.room_id, e.type,"
+                " state_key, redacts"
+                " FROM events AS e"
+                " LEFT JOIN redactions USING (event_id)"
+                " LEFT JOIN state_events USING (event_id)"
+                " WHERE ? > stream_ordering AND stream_ordering >= ?"
+                " ORDER BY stream_ordering DESC"
+                " LIMIT ?"
+            )
+            if have_backfill_events:
+                txn.execute(sql, (-last_backfill_id, -current_backfill_id, limit))
+                new_backfill_events = txn.fetchall()
+
+                if len(new_backfill_events) == limit:
+                    upper_bound = new_backfill_events[-1][0]
+                else:
+                    upper_bound = current_backfill_id
+
+                sql = (
+                    "SELECT -event_stream_ordering, event_id, state_group"
+                    " FROM ex_outlier_stream"
+                    " WHERE ? > event_stream_ordering"
+                    " AND event_stream_ordering >= ?"
+                    " ORDER BY event_stream_ordering DESC"
+                )
+                txn.execute(sql, (-last_backfill_id, -upper_bound))
+                backward_ex_outliers = txn.fetchall()
+            else:
+                new_backfill_events = []
+                backward_ex_outliers = []
+
+            return AllNewEventsResult(
+                new_forward_events,
+                new_backfill_events,
+                forward_ex_outliers,
+                backward_ex_outliers,
+            )
+
+        return self.db.runInteraction("get_all_new_events", get_all_new_events_txn)
+
+    async def is_event_after(self, event_id1, event_id2):
+        """Returns True if event_id1 is after event_id2 in the stream
+        """
+        to_1, so_1 = await self.get_event_ordering(event_id1)
+        to_2, so_2 = await self.get_event_ordering(event_id2)
+        return (to_1, so_1) > (to_2, so_2)
+
+    @cachedInlineCallbacks(max_entries=5000)
+    def get_event_ordering(self, event_id):
+        res = yield self.db.simple_select_one(
+            table="events",
+            retcols=["topological_ordering", "stream_ordering"],
+            keyvalues={"event_id": event_id},
+            allow_none=True,
+        )
+
+        if not res:
+            raise SynapseError(404, "Could not find event %s" % (event_id,))
+
+        return (int(res["topological_ordering"]), int(res["stream_ordering"]))
+
+    def get_next_event_to_expire(self):
+        """Retrieve the entry with the lowest expiry timestamp in the event_expiry
+        table, or None if there's no more event to expire.
+
+        Returns: Deferred[Optional[Tuple[str, int]]]
+            A tuple containing the event ID as its first element and an expiry timestamp
+            as its second one, if there's at least one row in the event_expiry table.
+            None otherwise.
+        """
+
+        def get_next_event_to_expire_txn(txn):
+            txn.execute(
+                """
+                SELECT event_id, expiry_ts FROM event_expiry
+                ORDER BY expiry_ts ASC LIMIT 1
+                """
+            )
+
+            return txn.fetchone()
+
+        return self.db.runInteraction(
+            desc="get_next_event_to_expire", func=get_next_event_to_expire_txn
+        )
+
+
+AllNewEventsResult = namedtuple(
+    "AllNewEventsResult",
+    [
+        "new_forward_events",
+        "new_backfill_events",
+        "forward_ex_outliers",
+        "backward_ex_outliers",
+    ],
+)
diff --git a/synapse/storage/filtering.py b/synapse/storage/data_stores/main/filtering.py
index 23b48f6cea..342d6622a4 100644
--- a/synapse/storage/filtering.py
+++ b/synapse/storage/data_stores/main/filtering.py
@@ -16,10 +16,9 @@
 from canonicaljson import encode_canonical_json
 
 from synapse.api.errors import Codes, SynapseError
+from synapse.storage._base import SQLBaseStore, db_to_json
 from synapse.util.caches.descriptors import cachedInlineCallbacks
 
-from ._base import SQLBaseStore, db_to_json
-
 
 class FilteringStore(SQLBaseStore):
     @cachedInlineCallbacks(num_args=2)
@@ -31,7 +30,7 @@ class FilteringStore(SQLBaseStore):
         except ValueError:
             raise SynapseError(400, "Invalid filter ID", Codes.INVALID_PARAM)
 
-        def_json = yield self._simple_select_one_onecol(
+        def_json = yield self.db.simple_select_one_onecol(
             table="user_filters",
             keyvalues={"user_id": user_localpart, "filter_id": filter_id},
             retcol="filter_json",
@@ -51,12 +50,12 @@ class FilteringStore(SQLBaseStore):
                 "SELECT filter_id FROM user_filters "
                 "WHERE user_id = ? AND filter_json = ?"
             )
-            txn.execute(sql, (user_localpart, def_json))
+            txn.execute(sql, (user_localpart, bytearray(def_json)))
             filter_id_response = txn.fetchone()
             if filter_id_response is not None:
                 return filter_id_response[0]
 
-            sql = "SELECT MAX(filter_id) FROM user_filters " "WHERE user_id = ?"
+            sql = "SELECT MAX(filter_id) FROM user_filters WHERE user_id = ?"
             txn.execute(sql, (user_localpart,))
             max_id = txn.fetchone()[0]
             if max_id is None:
@@ -68,8 +67,8 @@ class FilteringStore(SQLBaseStore):
                 "INSERT INTO user_filters (user_id, filter_id, filter_json)"
                 "VALUES(?, ?, ?)"
             )
-            txn.execute(sql, (user_localpart, filter_id, def_json))
+            txn.execute(sql, (user_localpart, filter_id, bytearray(def_json)))
 
             return filter_id
 
-        return self.runInteraction("add_user_filter", _do_txn)
+        return self.db.runInteraction("add_user_filter", _do_txn)
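
The add_user_filter change also binds the canonical-JSON blob as a bytearray, so both parameterised queries treat it as a binary value rather than text. A short illustration of the encoding step, assuming the canonicaljson package imported at the top of this file:

    from canonicaljson import encode_canonical_json

    user_filter = {"room": {"timeline": {"limit": 10}}}

    def_json = encode_canonical_json(user_filter)  # deterministic, compact bytes
    param = bytearray(def_json)                    # bound as a binary DB parameter

    assert isinstance(def_json, bytes)
    assert bytes(param) == b'{"room":{"timeline":{"limit":10}}}'
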
diff --git a/synapse/storage/group_server.py b/synapse/storage/data_stores/main/group_server.py
index 15b01c6958..fb1361f1c1 100644
--- a/synapse/storage/group_server.py
+++ b/synapse/storage/data_stores/main/group_server.py
@@ -19,8 +19,7 @@ from canonicaljson import json
 from twisted.internet import defer
 
 from synapse.api.errors import SynapseError
-
-from ._base import SQLBaseStore
+from synapse.storage._base import SQLBaseStore
 
 # The category ID for the "default" category. We don't store as null in the
 # database to avoid the fun of null != null
@@ -28,23 +27,9 @@ _DEFAULT_CATEGORY_ID = ""
 _DEFAULT_ROLE_ID = ""
 
 
-class GroupServerStore(SQLBaseStore):
-    def set_group_join_policy(self, group_id, join_policy):
-        """Set the join policy of a group.
-
-        join_policy can be one of:
-         * "invite"
-         * "open"
-        """
-        return self._simple_update_one(
-            table="groups",
-            keyvalues={"group_id": group_id},
-            updatevalues={"join_policy": join_policy},
-            desc="set_group_join_policy",
-        )
-
+class GroupServerWorkerStore(SQLBaseStore):
     def get_group(self, group_id):
-        return self._simple_select_one(
+        return self.db.simple_select_one(
             table="groups",
             keyvalues={"group_id": group_id},
             retcols=(
@@ -66,7 +51,7 @@ class GroupServerStore(SQLBaseStore):
         if not include_private:
             keyvalues["is_public"] = True
 
-        return self._simple_select_list(
+        return self.db.simple_select_list(
             table="group_users",
             keyvalues=keyvalues,
             retcols=("user_id", "is_public", "is_admin"),
@@ -76,31 +61,85 @@ class GroupServerStore(SQLBaseStore):
     def get_invited_users_in_group(self, group_id):
         # TODO: Pagination
 
-        return self._simple_select_onecol(
+        return self.db.simple_select_onecol(
             table="group_invites",
             keyvalues={"group_id": group_id},
             retcol="user_id",
             desc="get_invited_users_in_group",
         )
 
-    def get_rooms_in_group(self, group_id, include_private=False):
+    def get_rooms_in_group(self, group_id: str, include_private: bool = False):
+        """Retrieve the rooms that belong to a given group. Does not return rooms that
+        lack members.
+
+        Args:
+            group_id: The ID of the group to query for rooms
+            include_private: Whether to return private rooms in results
+
+        Returns:
+            Deferred[List[Dict[str, str|bool]]]: A list of dictionaries, each in the
+            form of:
+
+            {
+              "room_id": "!a_room_id:example.com",  # The ID of the room
+              "is_public": False                    # Whether this is a public room or not
+            }
+        """
         # TODO: Pagination
 
-        keyvalues = {"group_id": group_id}
-        if not include_private:
-            keyvalues["is_public"] = True
+        def _get_rooms_in_group_txn(txn):
+            sql = """
+            SELECT room_id, is_public FROM group_rooms
+                WHERE group_id = ?
+                AND room_id IN (
+                    SELECT group_rooms.room_id FROM group_rooms
+                    LEFT JOIN room_stats_current ON
+                        group_rooms.room_id = room_stats_current.room_id
+                        AND joined_members > 0
+                        AND local_users_in_room > 0
+                    LEFT JOIN rooms ON
+                        group_rooms.room_id = rooms.room_id
+                        AND (room_version <> '') = ?
+                )
+            """
+            args = [group_id, False]
 
-        return self._simple_select_list(
-            table="group_rooms",
-            keyvalues=keyvalues,
-            retcols=("room_id", "is_public"),
-            desc="get_rooms_in_group",
-        )
+            if not include_private:
+                sql += " AND is_public = ?"
+                args += [True]
+
+            txn.execute(sql, args)
+
+            return [
+                {"room_id": room_id, "is_public": is_public}
+                for room_id, is_public in txn
+            ]
 
-    def get_rooms_for_summary_by_category(self, group_id, include_private=False):
+        return self.db.runInteraction("get_rooms_in_group", _get_rooms_in_group_txn)
+
+    def get_rooms_for_summary_by_category(
+        self, group_id: str, include_private: bool = False,
+    ):
         """Get the rooms and categories that should be included in a summary request
 
-        Returns ([rooms], [categories])
+        Args:
+            group_id: The ID of the group to query the summary for
+            include_private: Whether to return private rooms in results
+
+        Returns:
+            Deferred[Tuple[List, Dict]]: A tuple containing:
+
+                * A list of dictionaries with the keys:
+                    * "room_id": str, the room ID
+                    * "is_public": bool, whether the room is public
+                    * "category_id": str|None, the category ID if set, else None
+                    * "order": int, the sort order of rooms
+
+                * A dictionary mapping each category_id (str) to a dictionary
+                    with the keys:
+                        * "is_public": bool, whether the category is public
+                        * "profile": str, the category profile
+                        * "order": int, the sort order of rooms in this category
         """
 
         def _get_rooms_for_summary_txn(txn):
@@ -112,13 +151,23 @@ class GroupServerStore(SQLBaseStore):
                 SELECT room_id, is_public, category_id, room_order
                 FROM group_summary_rooms
                 WHERE group_id = ?
+                AND room_id IN (
+                    SELECT group_rooms.room_id FROM group_rooms
+                    LEFT JOIN room_stats_current ON
+                        group_rooms.room_id = room_stats_current.room_id
+                        AND joined_members > 0
+                        AND local_users_in_room > 0
+                    LEFT JOIN rooms ON
+                        group_rooms.room_id = rooms.room_id
+                        AND (room_version <> '') = ?
+                )
             """
 
             if not include_private:
                 sql += " AND is_public = ?"
-                txn.execute(sql, (group_id, True))
+                txn.execute(sql, (group_id, False, True))
             else:
-                txn.execute(sql, (group_id,))
+                txn.execute(sql, (group_id, False))
 
             rooms = [
                 {
@@ -154,10 +203,372 @@ class GroupServerStore(SQLBaseStore):
 
             return rooms, categories
 
-        return self.runInteraction("get_rooms_for_summary", _get_rooms_for_summary_txn)
+        return self.db.runInteraction(
+            "get_rooms_for_summary", _get_rooms_for_summary_txn
+        )
+
+    @defer.inlineCallbacks
+    def get_group_categories(self, group_id):
+        rows = yield self.db.simple_select_list(
+            table="group_room_categories",
+            keyvalues={"group_id": group_id},
+            retcols=("category_id", "is_public", "profile"),
+            desc="get_group_categories",
+        )
+
+        return {
+            row["category_id"]: {
+                "is_public": row["is_public"],
+                "profile": json.loads(row["profile"]),
+            }
+            for row in rows
+        }
+
+    @defer.inlineCallbacks
+    def get_group_category(self, group_id, category_id):
+        category = yield self.db.simple_select_one(
+            table="group_room_categories",
+            keyvalues={"group_id": group_id, "category_id": category_id},
+            retcols=("is_public", "profile"),
+            desc="get_group_category",
+        )
+
+        category["profile"] = json.loads(category["profile"])
+
+        return category
+
+    @defer.inlineCallbacks
+    def get_group_roles(self, group_id):
+        rows = yield self.db.simple_select_list(
+            table="group_roles",
+            keyvalues={"group_id": group_id},
+            retcols=("role_id", "is_public", "profile"),
+            desc="get_group_roles",
+        )
+
+        return {
+            row["role_id"]: {
+                "is_public": row["is_public"],
+                "profile": json.loads(row["profile"]),
+            }
+            for row in rows
+        }
+
+    @defer.inlineCallbacks
+    def get_group_role(self, group_id, role_id):
+        role = yield self.db.simple_select_one(
+            table="group_roles",
+            keyvalues={"group_id": group_id, "role_id": role_id},
+            retcols=("is_public", "profile"),
+            desc="get_group_role",
+        )
+
+        role["profile"] = json.loads(role["profile"])
+
+        return role
+
+    def get_local_groups_for_room(self, room_id):
+        """Get all of the local group that contain a given room
+        Args:
+            room_id (str): The ID of a room
+        Returns:
+            Deferred[list[str]]: A list of the IDs of the local groups that
+                contain this room
+        """
+        return self.db.simple_select_onecol(
+            table="group_rooms",
+            keyvalues={"room_id": room_id},
+            retcol="group_id",
+            desc="get_local_groups_for_room",
+        )
+
+    def get_users_for_summary_by_role(self, group_id, include_private=False):
+        """Get the users and roles that should be included in a summary request
+
+        Returns ([users], [roles])
+        """
+
+        def _get_users_for_summary_txn(txn):
+            keyvalues = {"group_id": group_id}
+            if not include_private:
+                keyvalues["is_public"] = True
+
+            sql = """
+                SELECT user_id, is_public, role_id, user_order
+                FROM group_summary_users
+                WHERE group_id = ?
+            """
+
+            if not include_private:
+                sql += " AND is_public = ?"
+                txn.execute(sql, (group_id, True))
+            else:
+                txn.execute(sql, (group_id,))
+
+            users = [
+                {
+                    "user_id": row[0],
+                    "is_public": row[1],
+                    "role_id": row[2] if row[2] != _DEFAULT_ROLE_ID else None,
+                    "order": row[3],
+                }
+                for row in txn
+            ]
+
+            sql = """
+                SELECT role_id, is_public, profile, role_order
+                FROM group_summary_roles
+                INNER JOIN group_roles USING (group_id, role_id)
+                WHERE group_id = ?
+            """
+
+            if not include_private:
+                sql += " AND is_public = ?"
+                txn.execute(sql, (group_id, True))
+            else:
+                txn.execute(sql, (group_id,))
+
+            roles = {
+                row[0]: {
+                    "is_public": row[1],
+                    "profile": json.loads(row[2]),
+                    "order": row[3],
+                }
+                for row in txn
+            }
+
+            return users, roles
+
+        return self.db.runInteraction(
+            "get_users_for_summary_by_role", _get_users_for_summary_txn
+        )
+
+    def is_user_in_group(self, user_id, group_id):
+        return self.db.simple_select_one_onecol(
+            table="group_users",
+            keyvalues={"group_id": group_id, "user_id": user_id},
+            retcol="user_id",
+            allow_none=True,
+            desc="is_user_in_group",
+        ).addCallback(lambda r: bool(r))
+
+    def is_user_admin_in_group(self, group_id, user_id):
+        return self.db.simple_select_one_onecol(
+            table="group_users",
+            keyvalues={"group_id": group_id, "user_id": user_id},
+            retcol="is_admin",
+            allow_none=True,
+            desc="is_user_admin_in_group",
+        )
+
+    def is_user_invited_to_local_group(self, group_id, user_id):
+        """Has the group server invited a user?
+        """
+        return self.db.simple_select_one_onecol(
+            table="group_invites",
+            keyvalues={"group_id": group_id, "user_id": user_id},
+            retcol="user_id",
+            desc="is_user_invited_to_local_group",
+            allow_none=True,
+        )
+
+    def get_users_membership_info_in_group(self, group_id, user_id):
+        """Get a dict describing the membership of a user in a group.
+
+        Example if joined:
+
+            {
+                "membership": "join",
+                "is_public": True,
+                "is_privileged": False,
+            }
+
+        Returns an empty dict if the user is not joined/invited/etc.
+        """
+
+        def _get_users_membership_in_group_txn(txn):
+            row = self.db.simple_select_one_txn(
+                txn,
+                table="group_users",
+                keyvalues={"group_id": group_id, "user_id": user_id},
+                retcols=("is_admin", "is_public"),
+                allow_none=True,
+            )
+
+            if row:
+                return {
+                    "membership": "join",
+                    "is_public": row["is_public"],
+                    "is_privileged": row["is_admin"],
+                }
+
+            row = self.db.simple_select_one_onecol_txn(
+                txn,
+                table="group_invites",
+                keyvalues={"group_id": group_id, "user_id": user_id},
+                retcol="user_id",
+                allow_none=True,
+            )
+
+            if row:
+                return {"membership": "invite"}
+
+            return {}
+
+        return self.db.runInteraction(
+            "get_users_membership_info_in_group", _get_users_membership_in_group_txn
+        )
+
+    def get_publicised_groups_for_user(self, user_id):
+        """Get all groups a user is publicising
+        """
+        return self.db.simple_select_onecol(
+            table="local_group_membership",
+            keyvalues={"user_id": user_id, "membership": "join", "is_publicised": True},
+            retcol="group_id",
+            desc="get_publicised_groups_for_user",
+        )
+
+    def get_attestations_need_renewals(self, valid_until_ms):
+        """Get all attestations that need to be renewed until givent time
+        """
+
+        def _get_attestations_need_renewals_txn(txn):
+            sql = """
+                SELECT group_id, user_id FROM group_attestations_renewals
+                WHERE valid_until_ms <= ?
+            """
+            txn.execute(sql, (valid_until_ms,))
+            return self.db.cursor_to_dict(txn)
+
+        return self.db.runInteraction(
+            "get_attestations_need_renewals", _get_attestations_need_renewals_txn
+        )
+
+    @defer.inlineCallbacks
+    def get_remote_attestation(self, group_id, user_id):
+        """Get the attestation that proves the remote agrees that the user is
+        in the group.
+        """
+        row = yield self.db.simple_select_one(
+            table="group_attestations_remote",
+            keyvalues={"group_id": group_id, "user_id": user_id},
+            retcols=("valid_until_ms", "attestation_json"),
+            desc="get_remote_attestation",
+            allow_none=True,
+        )
+
+        now = int(self._clock.time_msec())
+        if row and now < row["valid_until_ms"]:
+            return json.loads(row["attestation_json"])
+
+        return None
+
+    def get_joined_groups(self, user_id):
+        return self.db.simple_select_onecol(
+            table="local_group_membership",
+            keyvalues={"user_id": user_id, "membership": "join"},
+            retcol="group_id",
+            desc="get_joined_groups",
+        )
+
+    def get_all_groups_for_user(self, user_id, now_token):
+        def _get_all_groups_for_user_txn(txn):
+            sql = """
+                SELECT group_id, type, membership, u.content
+                FROM local_group_updates AS u
+                INNER JOIN local_group_membership USING (group_id, user_id)
+                WHERE user_id = ? AND membership != 'leave'
+                    AND stream_id <= ?
+            """
+            txn.execute(sql, (user_id, now_token))
+            return [
+                {
+                    "group_id": row[0],
+                    "type": row[1],
+                    "membership": row[2],
+                    "content": json.loads(row[3]),
+                }
+                for row in txn
+            ]
+
+        return self.db.runInteraction(
+            "get_all_groups_for_user", _get_all_groups_for_user_txn
+        )
+
+    def get_groups_changes_for_user(self, user_id, from_token, to_token):
+        from_token = int(from_token)
+        has_changed = self._group_updates_stream_cache.has_entity_changed(
+            user_id, from_token
+        )
+        if not has_changed:
+            return defer.succeed([])
+
+        def _get_groups_changes_for_user_txn(txn):
+            sql = """
+                SELECT group_id, membership, type, u.content
+                FROM local_group_updates AS u
+                INNER JOIN local_group_membership USING (group_id, user_id)
+                WHERE user_id = ? AND ? < stream_id AND stream_id <= ?
+            """
+            txn.execute(sql, (user_id, from_token, to_token))
+            return [
+                {
+                    "group_id": group_id,
+                    "membership": membership,
+                    "type": gtype,
+                    "content": json.loads(content_json),
+                }
+                for group_id, membership, gtype, content_json in txn
+            ]
+
+        return self.db.runInteraction(
+            "get_groups_changes_for_user", _get_groups_changes_for_user_txn
+        )
+
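
get_groups_changes_for_user (and get_all_groups_changes below) first consult an in-memory stream change cache, so pollers whose token is already up to date skip the database entirely. A simplified stand-in for that check (not the real StreamChangeCache, which also bounds how far back it can answer):

    class TinyStreamChangeCache:
        """Remembers the last stream position at which each entity changed."""

        def __init__(self) -> None:
            self._last_change = {}

        def entity_has_changed(self, entity: str, stream_pos: int) -> None:
            current = self._last_change.get(entity, 0)
            self._last_change[entity] = max(current, stream_pos)

        def has_entity_changed(self, entity: str, stream_pos: int) -> bool:
            # Unknown entities are treated conservatively (assume changed).
            return self._last_change.get(entity, float("inf")) > stream_pos

    cache = TinyStreamChangeCache()
    cache.entity_has_changed("@alice:example.com", stream_pos=5)
    assert cache.has_entity_changed("@alice:example.com", stream_pos=3)      # query the DB
    assert not cache.has_entity_changed("@alice:example.com", stream_pos=5)  # return [] early
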
+    def get_all_groups_changes(self, from_token, to_token, limit):
+        from_token = int(from_token)
+        has_changed = self._group_updates_stream_cache.has_any_entity_changed(
+            from_token
+        )
+        if not has_changed:
+            return defer.succeed([])
+
+        def _get_all_groups_changes_txn(txn):
+            sql = """
+                SELECT stream_id, group_id, user_id, type, content
+                FROM local_group_updates
+                WHERE ? < stream_id AND stream_id <= ?
+                LIMIT ?
+            """
+            txn.execute(sql, (from_token, to_token, limit))
+            return [
+                (stream_id, group_id, user_id, gtype, json.loads(content_json))
+                for stream_id, group_id, user_id, gtype, content_json in txn
+            ]
+
+        return self.db.runInteraction(
+            "get_all_groups_changes", _get_all_groups_changes_txn
+        )
+
+
+class GroupServerStore(GroupServerWorkerStore):
+    def set_group_join_policy(self, group_id, join_policy):
+        """Set the join policy of a group.
+
+        join_policy can be one of:
+         * "invite"
+         * "open"
+        """
+        return self.db.simple_update_one(
+            table="groups",
+            keyvalues={"group_id": group_id},
+            updatevalues={"join_policy": join_policy},
+            desc="set_group_join_policy",
+        )
 
     def add_room_to_summary(self, group_id, room_id, category_id, order, is_public):
-        return self.runInteraction(
+        return self.db.runInteraction(
             "add_room_to_summary",
             self._add_room_to_summary_txn,
             group_id,
@@ -181,7 +592,7 @@ class GroupServerStore(SQLBaseStore):
                 an order of 1 will put the room first. Otherwise, the room gets
                 added to the end.
         """
-        room_in_group = self._simple_select_one_onecol_txn(
+        room_in_group = self.db.simple_select_one_onecol_txn(
             txn,
             table="group_rooms",
             keyvalues={"group_id": group_id, "room_id": room_id},
@@ -194,7 +605,7 @@ class GroupServerStore(SQLBaseStore):
         if category_id is None:
             category_id = _DEFAULT_CATEGORY_ID
         else:
-            cat_exists = self._simple_select_one_onecol_txn(
+            cat_exists = self.db.simple_select_one_onecol_txn(
                 txn,
                 table="group_room_categories",
                 keyvalues={"group_id": group_id, "category_id": category_id},
@@ -205,7 +616,7 @@ class GroupServerStore(SQLBaseStore):
                 raise SynapseError(400, "Category doesn't exist")
 
             # TODO: Check category is part of summary already
-            cat_exists = self._simple_select_one_onecol_txn(
+            cat_exists = self.db.simple_select_one_onecol_txn(
                 txn,
                 table="group_summary_room_categories",
                 keyvalues={"group_id": group_id, "category_id": category_id},
@@ -225,7 +636,7 @@ class GroupServerStore(SQLBaseStore):
                     (group_id, category_id, group_id, category_id),
                 )
 
-        existing = self._simple_select_one_txn(
+        existing = self.db.simple_select_one_txn(
             txn,
             table="group_summary_rooms",
             keyvalues={
@@ -250,7 +661,7 @@ class GroupServerStore(SQLBaseStore):
                 WHERE group_id = ? AND category_id = ?
             """
             txn.execute(sql, (group_id, category_id))
-            order, = txn.fetchone()
+            (order,) = txn.fetchone()
 
         if existing:
             to_update = {}
@@ -258,7 +669,7 @@ class GroupServerStore(SQLBaseStore):
                 to_update["room_order"] = order
             if is_public is not None:
                 to_update["is_public"] = is_public
-            self._simple_update_txn(
+            self.db.simple_update_txn(
                 txn,
                 table="group_summary_rooms",
                 keyvalues={
@@ -272,7 +683,7 @@ class GroupServerStore(SQLBaseStore):
             if is_public is None:
                 is_public = True
 
-            self._simple_insert_txn(
+            self.db.simple_insert_txn(
                 txn,
                 table="group_summary_rooms",
                 values={
@@ -288,7 +699,7 @@ class GroupServerStore(SQLBaseStore):
         if category_id is None:
             category_id = _DEFAULT_CATEGORY_ID
 
-        return self._simple_delete(
+        return self.db.simple_delete(
             table="group_summary_rooms",
             keyvalues={
                 "group_id": group_id,
@@ -298,36 +709,6 @@ class GroupServerStore(SQLBaseStore):
             desc="remove_room_from_summary",
         )
 
-    @defer.inlineCallbacks
-    def get_group_categories(self, group_id):
-        rows = yield self._simple_select_list(
-            table="group_room_categories",
-            keyvalues={"group_id": group_id},
-            retcols=("category_id", "is_public", "profile"),
-            desc="get_group_categories",
-        )
-
-        return {
-            row["category_id"]: {
-                "is_public": row["is_public"],
-                "profile": json.loads(row["profile"]),
-            }
-            for row in rows
-        }
-
-    @defer.inlineCallbacks
-    def get_group_category(self, group_id, category_id):
-        category = yield self._simple_select_one(
-            table="group_room_categories",
-            keyvalues={"group_id": group_id, "category_id": category_id},
-            retcols=("is_public", "profile"),
-            desc="get_group_category",
-        )
-
-        category["profile"] = json.loads(category["profile"])
-
-        return category
-
     def upsert_group_category(self, group_id, category_id, profile, is_public):
         """Add/update room category for group
         """
@@ -344,7 +725,7 @@ class GroupServerStore(SQLBaseStore):
         else:
             update_values["is_public"] = is_public
 
-        return self._simple_upsert(
+        return self.db.simple_upsert(
             table="group_room_categories",
             keyvalues={"group_id": group_id, "category_id": category_id},
             values=update_values,
@@ -353,42 +734,12 @@ class GroupServerStore(SQLBaseStore):
         )
 
     def remove_group_category(self, group_id, category_id):
-        return self._simple_delete(
+        return self.db.simple_delete(
             table="group_room_categories",
             keyvalues={"group_id": group_id, "category_id": category_id},
             desc="remove_group_category",
         )
 
-    @defer.inlineCallbacks
-    def get_group_roles(self, group_id):
-        rows = yield self._simple_select_list(
-            table="group_roles",
-            keyvalues={"group_id": group_id},
-            retcols=("role_id", "is_public", "profile"),
-            desc="get_group_roles",
-        )
-
-        return {
-            row["role_id"]: {
-                "is_public": row["is_public"],
-                "profile": json.loads(row["profile"]),
-            }
-            for row in rows
-        }
-
-    @defer.inlineCallbacks
-    def get_group_role(self, group_id, role_id):
-        role = yield self._simple_select_one(
-            table="group_roles",
-            keyvalues={"group_id": group_id, "role_id": role_id},
-            retcols=("is_public", "profile"),
-            desc="get_group_role",
-        )
-
-        role["profile"] = json.loads(role["profile"])
-
-        return role
-
     def upsert_group_role(self, group_id, role_id, profile, is_public):
         """Add/remove user role
         """
@@ -405,7 +756,7 @@ class GroupServerStore(SQLBaseStore):
         else:
             update_values["is_public"] = is_public
 
-        return self._simple_upsert(
+        return self.db.simple_upsert(
             table="group_roles",
             keyvalues={"group_id": group_id, "role_id": role_id},
             values=update_values,
@@ -414,14 +765,14 @@ class GroupServerStore(SQLBaseStore):
         )
 
     def remove_group_role(self, group_id, role_id):
-        return self._simple_delete(
+        return self.db.simple_delete(
             table="group_roles",
             keyvalues={"group_id": group_id, "role_id": role_id},
             desc="remove_group_role",
         )
 
     def add_user_to_summary(self, group_id, user_id, role_id, order, is_public):
-        return self.runInteraction(
+        return self.db.runInteraction(
             "add_user_to_summary",
             self._add_user_to_summary_txn,
             group_id,
@@ -445,7 +796,7 @@ class GroupServerStore(SQLBaseStore):
                 an order of 1 will put the user first. Otherwise, the user gets
                 added to the end.
         """
-        user_in_group = self._simple_select_one_onecol_txn(
+        user_in_group = self.db.simple_select_one_onecol_txn(
             txn,
             table="group_users",
             keyvalues={"group_id": group_id, "user_id": user_id},
@@ -458,7 +809,7 @@ class GroupServerStore(SQLBaseStore):
         if role_id is None:
             role_id = _DEFAULT_ROLE_ID
         else:
-            role_exists = self._simple_select_one_onecol_txn(
+            role_exists = self.db.simple_select_one_onecol_txn(
                 txn,
                 table="group_roles",
                 keyvalues={"group_id": group_id, "role_id": role_id},
@@ -469,7 +820,7 @@ class GroupServerStore(SQLBaseStore):
                 raise SynapseError(400, "Role doesn't exist")
 
             # TODO: Check role is part of the summary already
-            role_exists = self._simple_select_one_onecol_txn(
+            role_exists = self.db.simple_select_one_onecol_txn(
                 txn,
                 table="group_summary_roles",
                 keyvalues={"group_id": group_id, "role_id": role_id},
@@ -489,7 +840,7 @@ class GroupServerStore(SQLBaseStore):
                     (group_id, role_id, group_id, role_id),
                 )
 
-        existing = self._simple_select_one_txn(
+        existing = self.db.simple_select_one_txn(
             txn,
             table="group_summary_users",
             keyvalues={"group_id": group_id, "user_id": user_id, "role_id": role_id},
@@ -510,7 +861,7 @@ class GroupServerStore(SQLBaseStore):
                 WHERE group_id = ? AND role_id = ?
             """
             txn.execute(sql, (group_id, role_id))
-            order, = txn.fetchone()
+            (order,) = txn.fetchone()
 
         if existing:
             to_update = {}
@@ -518,7 +869,7 @@ class GroupServerStore(SQLBaseStore):
                 to_update["user_order"] = order
             if is_public is not None:
                 to_update["is_public"] = is_public
-            self._simple_update_txn(
+            self.db.simple_update_txn(
                 txn,
                 table="group_summary_users",
                 keyvalues={
@@ -532,7 +883,7 @@ class GroupServerStore(SQLBaseStore):
             if is_public is None:
                 is_public = True
 
-            self._simple_insert_txn(
+            self.db.simple_insert_txn(
                 txn,
                 table="group_summary_users",
                 values={
@@ -548,158 +899,21 @@ class GroupServerStore(SQLBaseStore):
         if role_id is None:
             role_id = _DEFAULT_ROLE_ID
 
-        return self._simple_delete(
+        return self.db.simple_delete(
             table="group_summary_users",
             keyvalues={"group_id": group_id, "role_id": role_id, "user_id": user_id},
             desc="remove_user_from_summary",
         )
 
-    def get_users_for_summary_by_role(self, group_id, include_private=False):
-        """Get the users and roles that should be included in a summary request
-
-        Returns ([users], [roles])
-        """
-
-        def _get_users_for_summary_txn(txn):
-            keyvalues = {"group_id": group_id}
-            if not include_private:
-                keyvalues["is_public"] = True
-
-            sql = """
-                SELECT user_id, is_public, role_id, user_order
-                FROM group_summary_users
-                WHERE group_id = ?
-            """
-
-            if not include_private:
-                sql += " AND is_public = ?"
-                txn.execute(sql, (group_id, True))
-            else:
-                txn.execute(sql, (group_id,))
-
-            users = [
-                {
-                    "user_id": row[0],
-                    "is_public": row[1],
-                    "role_id": row[2] if row[2] != _DEFAULT_ROLE_ID else None,
-                    "order": row[3],
-                }
-                for row in txn
-            ]
-
-            sql = """
-                SELECT role_id, is_public, profile, role_order
-                FROM group_summary_roles
-                INNER JOIN group_roles USING (group_id, role_id)
-                WHERE group_id = ?
-            """
-
-            if not include_private:
-                sql += " AND is_public = ?"
-                txn.execute(sql, (group_id, True))
-            else:
-                txn.execute(sql, (group_id,))
-
-            roles = {
-                row[0]: {
-                    "is_public": row[1],
-                    "profile": json.loads(row[2]),
-                    "order": row[3],
-                }
-                for row in txn
-            }
-
-            return users, roles
-
-        return self.runInteraction(
-            "get_users_for_summary_by_role", _get_users_for_summary_txn
-        )
-
-    def is_user_in_group(self, user_id, group_id):
-        return self._simple_select_one_onecol(
-            table="group_users",
-            keyvalues={"group_id": group_id, "user_id": user_id},
-            retcol="user_id",
-            allow_none=True,
-            desc="is_user_in_group",
-        ).addCallback(lambda r: bool(r))
-
-    def is_user_admin_in_group(self, group_id, user_id):
-        return self._simple_select_one_onecol(
-            table="group_users",
-            keyvalues={"group_id": group_id, "user_id": user_id},
-            retcol="is_admin",
-            allow_none=True,
-            desc="is_user_admin_in_group",
-        )
-
     def add_group_invite(self, group_id, user_id):
         """Record that the group server has invited a user
         """
-        return self._simple_insert(
+        return self.db.simple_insert(
             table="group_invites",
             values={"group_id": group_id, "user_id": user_id},
             desc="add_group_invite",
         )
 
-    def is_user_invited_to_local_group(self, group_id, user_id):
-        """Has the group server invited a user?
-        """
-        return self._simple_select_one_onecol(
-            table="group_invites",
-            keyvalues={"group_id": group_id, "user_id": user_id},
-            retcol="user_id",
-            desc="is_user_invited_to_local_group",
-            allow_none=True,
-        )
-
-    def get_users_membership_info_in_group(self, group_id, user_id):
-        """Get a dict describing the membership of a user in a group.
-
-        Example if joined:
-
-            {
-                "membership": "join",
-                "is_public": True,
-                "is_privileged": False,
-            }
-
-        Returns an empty dict if the user is not join/invite/etc
-        """
-
-        def _get_users_membership_in_group_txn(txn):
-            row = self._simple_select_one_txn(
-                txn,
-                table="group_users",
-                keyvalues={"group_id": group_id, "user_id": user_id},
-                retcols=("is_admin", "is_public"),
-                allow_none=True,
-            )
-
-            if row:
-                return {
-                    "membership": "join",
-                    "is_public": row["is_public"],
-                    "is_privileged": row["is_admin"],
-                }
-
-            row = self._simple_select_one_onecol_txn(
-                txn,
-                table="group_invites",
-                keyvalues={"group_id": group_id, "user_id": user_id},
-                retcol="user_id",
-                allow_none=True,
-            )
-
-            if row:
-                return {"membership": "invite"}
-
-            return {}
-
-        return self.runInteraction(
-            "get_users_membership_info_in_group", _get_users_membership_in_group_txn
-        )
-
     def add_user_to_group(
         self,
         group_id,
@@ -724,7 +938,7 @@ class GroupServerStore(SQLBaseStore):
         """
 
         def _add_user_to_group_txn(txn):
-            self._simple_insert_txn(
+            self.db.simple_insert_txn(
                 txn,
                 table="group_users",
                 values={
@@ -735,14 +949,14 @@ class GroupServerStore(SQLBaseStore):
                 },
             )
 
-            self._simple_delete_txn(
+            self.db.simple_delete_txn(
                 txn,
                 table="group_invites",
                 keyvalues={"group_id": group_id, "user_id": user_id},
             )
 
             if local_attestation:
-                self._simple_insert_txn(
+                self.db.simple_insert_txn(
                     txn,
                     table="group_attestations_renewals",
                     values={
@@ -752,7 +966,7 @@ class GroupServerStore(SQLBaseStore):
                     },
                 )
             if remote_attestation:
-                self._simple_insert_txn(
+                self.db.simple_insert_txn(
                     txn,
                     table="group_attestations_remote",
                     values={
@@ -763,49 +977,49 @@ class GroupServerStore(SQLBaseStore):
                     },
                 )
 
-        return self.runInteraction("add_user_to_group", _add_user_to_group_txn)
+        return self.db.runInteraction("add_user_to_group", _add_user_to_group_txn)
 
     def remove_user_from_group(self, group_id, user_id):
         def _remove_user_from_group_txn(txn):
-            self._simple_delete_txn(
+            self.db.simple_delete_txn(
                 txn,
                 table="group_users",
                 keyvalues={"group_id": group_id, "user_id": user_id},
             )
-            self._simple_delete_txn(
+            self.db.simple_delete_txn(
                 txn,
                 table="group_invites",
                 keyvalues={"group_id": group_id, "user_id": user_id},
             )
-            self._simple_delete_txn(
+            self.db.simple_delete_txn(
                 txn,
                 table="group_attestations_renewals",
                 keyvalues={"group_id": group_id, "user_id": user_id},
             )
-            self._simple_delete_txn(
+            self.db.simple_delete_txn(
                 txn,
                 table="group_attestations_remote",
                 keyvalues={"group_id": group_id, "user_id": user_id},
             )
-            self._simple_delete_txn(
+            self.db.simple_delete_txn(
                 txn,
                 table="group_summary_users",
                 keyvalues={"group_id": group_id, "user_id": user_id},
             )
 
-        return self.runInteraction(
+        return self.db.runInteraction(
             "remove_user_from_group", _remove_user_from_group_txn
         )
 
     def add_room_to_group(self, group_id, room_id, is_public):
-        return self._simple_insert(
+        return self.db.simple_insert(
             table="group_rooms",
             values={"group_id": group_id, "room_id": room_id, "is_public": is_public},
             desc="add_room_to_group",
         )
 
     def update_room_in_group_visibility(self, group_id, room_id, is_public):
-        return self._simple_update(
+        return self.db.simple_update(
             table="group_rooms",
             keyvalues={"group_id": group_id, "room_id": room_id},
             updatevalues={"is_public": is_public},
@@ -814,36 +1028,26 @@ class GroupServerStore(SQLBaseStore):
 
     def remove_room_from_group(self, group_id, room_id):
         def _remove_room_from_group_txn(txn):
-            self._simple_delete_txn(
+            self.db.simple_delete_txn(
                 txn,
                 table="group_rooms",
                 keyvalues={"group_id": group_id, "room_id": room_id},
             )
 
-            self._simple_delete_txn(
+            self.db.simple_delete_txn(
                 txn,
                 table="group_summary_rooms",
                 keyvalues={"group_id": group_id, "room_id": room_id},
             )
 
-        return self.runInteraction(
+        return self.db.runInteraction(
             "remove_room_from_group", _remove_room_from_group_txn
         )
 
-    def get_publicised_groups_for_user(self, user_id):
-        """Get all groups a user is publicising
-        """
-        return self._simple_select_onecol(
-            table="local_group_membership",
-            keyvalues={"user_id": user_id, "membership": "join", "is_publicised": True},
-            retcol="group_id",
-            desc="get_publicised_groups_for_user",
-        )
-
     def update_group_publicity(self, group_id, user_id, publicise):
         """Update whether the user is publicising their membership of the group
         """
-        return self._simple_update_one(
+        return self.db.simple_update_one(
             table="local_group_membership",
             keyvalues={"group_id": group_id, "user_id": user_id},
             updatevalues={"is_publicised": publicise},
@@ -879,12 +1083,12 @@ class GroupServerStore(SQLBaseStore):
 
         def _register_user_group_membership_txn(txn, next_id):
             # TODO: Upsert?
-            self._simple_delete_txn(
+            self.db.simple_delete_txn(
                 txn,
                 table="local_group_membership",
                 keyvalues={"group_id": group_id, "user_id": user_id},
             )
-            self._simple_insert_txn(
+            self.db.simple_insert_txn(
                 txn,
                 table="local_group_membership",
                 values={
@@ -897,7 +1101,7 @@ class GroupServerStore(SQLBaseStore):
                 },
             )
 
-            self._simple_insert_txn(
+            self.db.simple_insert_txn(
                 txn,
                 table="local_group_updates",
                 values={
@@ -916,7 +1120,7 @@ class GroupServerStore(SQLBaseStore):
 
             if membership == "join":
                 if local_attestation:
-                    self._simple_insert_txn(
+                    self.db.simple_insert_txn(
                         txn,
                         table="group_attestations_renewals",
                         values={
@@ -926,7 +1130,7 @@ class GroupServerStore(SQLBaseStore):
                         },
                     )
                 if remote_attestation:
-                    self._simple_insert_txn(
+                    self.db.simple_insert_txn(
                         txn,
                         table="group_attestations_remote",
                         values={
@@ -937,12 +1141,12 @@ class GroupServerStore(SQLBaseStore):
                         },
                     )
             else:
-                self._simple_delete_txn(
+                self.db.simple_delete_txn(
                     txn,
                     table="group_attestations_renewals",
                     keyvalues={"group_id": group_id, "user_id": user_id},
                 )
-                self._simple_delete_txn(
+                self.db.simple_delete_txn(
                     txn,
                     table="group_attestations_remote",
                     keyvalues={"group_id": group_id, "user_id": user_id},
@@ -951,7 +1155,7 @@ class GroupServerStore(SQLBaseStore):
             return next_id
 
         with self._group_updates_id_gen.get_next() as next_id:
-            res = yield self.runInteraction(
+            res = yield self.db.runInteraction(
                 "register_user_group_membership",
                 _register_user_group_membership_txn,
                 next_id,
@@ -962,7 +1166,7 @@ class GroupServerStore(SQLBaseStore):
     def create_group(
         self, group_id, user_id, name, avatar_url, short_description, long_description
     ):
-        yield self._simple_insert(
+        yield self.db.simple_insert(
             table="groups",
             values={
                 "group_id": group_id,
@@ -977,33 +1181,17 @@ class GroupServerStore(SQLBaseStore):
 
     @defer.inlineCallbacks
     def update_group_profile(self, group_id, profile):
-        yield self._simple_update_one(
+        yield self.db.simple_update_one(
             table="groups",
             keyvalues={"group_id": group_id},
             updatevalues=profile,
             desc="update_group_profile",
         )
 
-    def get_attestations_need_renewals(self, valid_until_ms):
-        """Get all attestations that need to be renewed until givent time
-        """
-
-        def _get_attestations_need_renewals_txn(txn):
-            sql = """
-                SELECT group_id, user_id FROM group_attestations_renewals
-                WHERE valid_until_ms <= ?
-            """
-            txn.execute(sql, (valid_until_ms,))
-            return self.cursor_to_dict(txn)
-
-        return self.runInteraction(
-            "get_attestations_need_renewals", _get_attestations_need_renewals_txn
-        )
-
     def update_attestation_renewal(self, group_id, user_id, attestation):
         """Update an attestation that we have renewed
         """
-        return self._simple_update_one(
+        return self.db.simple_update_one(
             table="group_attestations_renewals",
             keyvalues={"group_id": group_id, "user_id": user_id},
             updatevalues={"valid_until_ms": attestation["valid_until_ms"]},
@@ -1013,7 +1201,7 @@ class GroupServerStore(SQLBaseStore):
     def update_remote_attestion(self, group_id, user_id, attestation):
         """Update an attestation that a remote has renewed
         """
-        return self._simple_update_one(
+        return self.db.simple_update_one(
             table="group_attestations_remote",
             keyvalues={"group_id": group_id, "user_id": user_id},
             updatevalues={
@@ -1032,118 +1220,12 @@ class GroupServerStore(SQLBaseStore):
             group_id (str)
             user_id (str)
         """
-        return self._simple_delete(
+        return self.db.simple_delete(
             table="group_attestations_renewals",
             keyvalues={"group_id": group_id, "user_id": user_id},
             desc="remove_attestation_renewal",
         )
 
-    @defer.inlineCallbacks
-    def get_remote_attestation(self, group_id, user_id):
-        """Get the attestation that proves the remote agrees that the user is
-        in the group.
-        """
-        row = yield self._simple_select_one(
-            table="group_attestations_remote",
-            keyvalues={"group_id": group_id, "user_id": user_id},
-            retcols=("valid_until_ms", "attestation_json"),
-            desc="get_remote_attestation",
-            allow_none=True,
-        )
-
-        now = int(self._clock.time_msec())
-        if row and now < row["valid_until_ms"]:
-            return json.loads(row["attestation_json"])
-
-        return None
-
-    def get_joined_groups(self, user_id):
-        return self._simple_select_onecol(
-            table="local_group_membership",
-            keyvalues={"user_id": user_id, "membership": "join"},
-            retcol="group_id",
-            desc="get_joined_groups",
-        )
-
-    def get_all_groups_for_user(self, user_id, now_token):
-        def _get_all_groups_for_user_txn(txn):
-            sql = """
-                SELECT group_id, type, membership, u.content
-                FROM local_group_updates AS u
-                INNER JOIN local_group_membership USING (group_id, user_id)
-                WHERE user_id = ? AND membership != 'leave'
-                    AND stream_id <= ?
-            """
-            txn.execute(sql, (user_id, now_token))
-            return [
-                {
-                    "group_id": row[0],
-                    "type": row[1],
-                    "membership": row[2],
-                    "content": json.loads(row[3]),
-                }
-                for row in txn
-            ]
-
-        return self.runInteraction(
-            "get_all_groups_for_user", _get_all_groups_for_user_txn
-        )
-
-    def get_groups_changes_for_user(self, user_id, from_token, to_token):
-        from_token = int(from_token)
-        has_changed = self._group_updates_stream_cache.has_entity_changed(
-            user_id, from_token
-        )
-        if not has_changed:
-            return []
-
-        def _get_groups_changes_for_user_txn(txn):
-            sql = """
-                SELECT group_id, membership, type, u.content
-                FROM local_group_updates AS u
-                INNER JOIN local_group_membership USING (group_id, user_id)
-                WHERE user_id = ? AND ? < stream_id AND stream_id <= ?
-            """
-            txn.execute(sql, (user_id, from_token, to_token))
-            return [
-                {
-                    "group_id": group_id,
-                    "membership": membership,
-                    "type": gtype,
-                    "content": json.loads(content_json),
-                }
-                for group_id, membership, gtype, content_json in txn
-            ]
-
-        return self.runInteraction(
-            "get_groups_changes_for_user", _get_groups_changes_for_user_txn
-        )
-
-    def get_all_groups_changes(self, from_token, to_token, limit):
-        from_token = int(from_token)
-        has_changed = self._group_updates_stream_cache.has_any_entity_changed(
-            from_token
-        )
-        if not has_changed:
-            return []
-
-        def _get_all_groups_changes_txn(txn):
-            sql = """
-                SELECT stream_id, group_id, user_id, type, content
-                FROM local_group_updates
-                WHERE ? < stream_id AND stream_id <= ?
-                LIMIT ?
-            """
-            txn.execute(sql, (from_token, to_token, limit))
-            return [
-                (stream_id, group_id, user_id, gtype, json.loads(content_json))
-                for stream_id, group_id, user_id, gtype, content_json in txn
-            ]
-
-        return self.runInteraction(
-            "get_all_groups_changes", _get_all_groups_changes_txn
-        )
-
     def get_group_stream_token(self):
         return self._group_updates_id_gen.get_current_token()
 
@@ -1174,8 +1256,8 @@ class GroupServerStore(SQLBaseStore):
             ]
 
             for table in tables:
-                self._simple_delete_txn(
+                self.db.simple_delete_txn(
                     txn, table=table, keyvalues={"group_id": group_id}
                 )
 
-        return self.runInteraction("delete_group", _delete_group_txn)
+        return self.db.runInteraction("delete_group", _delete_group_txn)
diff --git a/synapse/storage/data_stores/main/keys.py b/synapse/storage/data_stores/main/keys.py
new file mode 100644
index 0000000000..4e1642a27a
--- /dev/null
+++ b/synapse/storage/data_stores/main/keys.py
@@ -0,0 +1,208 @@
+# -*- coding: utf-8 -*-
+# Copyright 2014-2016 OpenMarket Ltd
+# Copyright 2019 New Vector Ltd.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import itertools
+import logging
+
+from signedjson.key import decode_verify_key_bytes
+
+from synapse.storage._base import SQLBaseStore
+from synapse.storage.keys import FetchKeyResult
+from synapse.util.caches.descriptors import cached, cachedList
+from synapse.util.iterutils import batch_iter
+
+logger = logging.getLogger(__name__)
+
+
+db_binary_type = memoryview
+
+
+class KeyStore(SQLBaseStore):
+    """Persistence for signature verification keys
+    """
+
+    @cached()
+    def _get_server_verify_key(self, server_name_and_key_id):
+        raise NotImplementedError()
+
+    @cachedList(
+        cached_method_name="_get_server_verify_key", list_name="server_name_and_key_ids"
+    )
+    def get_server_verify_keys(self, server_name_and_key_ids):
+        """
+        Args:
+            server_name_and_key_ids (iterable[Tuple[str, str]]):
+                iterable of (server_name, key-id) tuples to fetch keys for
+
+        Returns:
+            Deferred: resolves to dict[Tuple[str, str], FetchKeyResult|None]:
+                map from (server_name, key_id) -> FetchKeyResult, or None if the key is
+                unknown
+        """
+        keys = {}
+
+        def _get_keys(txn, batch):
+            """Processes a batch of keys to fetch, and adds the result to `keys`."""
+
+            # batch_iter always returns tuples so it's safe to do len(batch)
+            sql = (
+                "SELECT server_name, key_id, verify_key, ts_valid_until_ms "
+                "FROM server_signature_keys WHERE 1=0"
+            ) + " OR (server_name=? AND key_id=?)" * len(batch)
+
+            txn.execute(sql, tuple(itertools.chain.from_iterable(batch)))
+
+            for row in txn:
+                server_name, key_id, key_bytes, ts_valid_until_ms = row
+
+                if ts_valid_until_ms is None:
+                    # Old keys may be stored with a ts_valid_until_ms of null,
+                    # in which case we treat this as if it was set to `0`, i.e.
+                    # it won't match key requests that define a minimum
+                    # `ts_valid_until_ms`.
+                    ts_valid_until_ms = 0
+
+                res = FetchKeyResult(
+                    verify_key=decode_verify_key_bytes(key_id, bytes(key_bytes)),
+                    valid_until_ts=ts_valid_until_ms,
+                )
+                keys[(server_name, key_id)] = res
+
+        def _txn(txn):
+            for batch in batch_iter(server_name_and_key_ids, 50):
+                _get_keys(txn, batch)
+            return keys
+
+        return self.db.runInteraction("get_server_verify_keys", _txn)
+
+    def store_server_verify_keys(self, from_server, ts_added_ms, verify_keys):
+        """Stores NACL verification keys for remote servers.
+        Args:
+            from_server (str): Where the verification keys were looked up
+            ts_added_ms (int): The time to record that the key was added
+            verify_keys (iterable[tuple[str, str, FetchKeyResult]]):
+                keys to be stored. Each entry is a triplet of
+                (server_name, key_id, key).
+        """
+        key_values = []
+        value_values = []
+        invalidations = []
+        for server_name, key_id, fetch_result in verify_keys:
+            key_values.append((server_name, key_id))
+            value_values.append(
+                (
+                    from_server,
+                    ts_added_ms,
+                    fetch_result.valid_until_ts,
+                    db_binary_type(fetch_result.verify_key.encode()),
+                )
+            )
+            # invalidate takes a tuple corresponding to the params of
+            # _get_server_verify_key. _get_server_verify_key only takes one
+            # param, which is itself the 2-tuple (server_name, key_id).
+            invalidations.append((server_name, key_id))
+
+        def _invalidate(res):
+            f = self._get_server_verify_key.invalidate
+            for i in invalidations:
+                f((i,))
+            return res
+
+        return self.db.runInteraction(
+            "store_server_verify_keys",
+            self.db.simple_upsert_many_txn,
+            table="server_signature_keys",
+            key_names=("server_name", "key_id"),
+            key_values=key_values,
+            value_names=(
+                "from_server",
+                "ts_added_ms",
+                "ts_valid_until_ms",
+                "verify_key",
+            ),
+            value_values=value_values,
+        ).addCallback(_invalidate)
+
+    def store_server_keys_json(
+        self, server_name, key_id, from_server, ts_now_ms, ts_expires_ms, key_json_bytes
+    ):
+        """Stores the JSON bytes for a set of keys from a server
+        The JSON should be signed by the originating server, the intermediate
+        server, and by this server. Updates the value for the
+        (server_name, key_id, from_server) triplet if one already existed.
+        Args:
+            server_name (str): The name of the server.
+            key_id (str): The identifier of the key this JSON is for.
+            from_server (str): The server this JSON was fetched from.
+            ts_now_ms (int): The time now in milliseconds.
+            ts_expires_ms (int): The time when this JSON stops being valid.
+            key_json_bytes (bytes): The encoded JSON.
+        """
+        return self.db.simple_upsert(
+            table="server_keys_json",
+            keyvalues={
+                "server_name": server_name,
+                "key_id": key_id,
+                "from_server": from_server,
+            },
+            values={
+                "server_name": server_name,
+                "key_id": key_id,
+                "from_server": from_server,
+                "ts_added_ms": ts_now_ms,
+                "ts_valid_until_ms": ts_expires_ms,
+                "key_json": db_binary_type(key_json_bytes),
+            },
+            desc="store_server_keys_json",
+        )
+
+    def get_server_keys_json(self, server_keys):
+        """Retrive the key json for a list of server_keys and key ids.
+        If no keys are found for a given server, key_id and source then
+        that server, key_id, and source triplet entry will be an empty list.
+        The JSON is returned as a byte array so that it can be efficiently
+        used in an HTTP response.
+        Args:
+            server_keys (list): List of (server_name, key_id, source) triplets.
+        Returns:
+            Deferred[dict[Tuple[str, str, str|None], list[dict]]]:
+                Dict mapping (server_name, key_id, source) triplets to lists of dicts
+        """
+
+        def _get_server_keys_json_txn(txn):
+            results = {}
+            for server_name, key_id, from_server in server_keys:
+                keyvalues = {"server_name": server_name}
+                if key_id is not None:
+                    keyvalues["key_id"] = key_id
+                if from_server is not None:
+                    keyvalues["from_server"] = from_server
+                rows = self.db.simple_select_list_txn(
+                    txn,
+                    "server_keys_json",
+                    keyvalues=keyvalues,
+                    retcols=(
+                        "key_id",
+                        "from_server",
+                        "ts_added_ms",
+                        "ts_valid_until_ms",
+                        "key_json",
+                    ),
+                )
+                results[(server_name, key_id, from_server)] = rows
+            return results
+
+        return self.db.runInteraction("get_server_keys_json", _get_server_keys_json_txn)
diff --git a/synapse/storage/media_repository.py b/synapse/storage/data_stores/main/media_repository.py
index 6b1238ce4a..8aecd414c2 100644
--- a/synapse/storage/media_repository.py
+++ b/synapse/storage/data_stores/main/media_repository.py
@@ -12,16 +12,17 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
-from synapse.storage.background_updates import BackgroundUpdateStore
+from synapse.storage._base import SQLBaseStore
+from synapse.storage.database import Database
 
 
-class MediaRepositoryStore(BackgroundUpdateStore):
-    """Persistence for attachments and avatars"""
-
-    def __init__(self, db_conn, hs):
-        super(MediaRepositoryStore, self).__init__(db_conn, hs)
+class MediaRepositoryBackgroundUpdateStore(SQLBaseStore):
+    def __init__(self, database: Database, db_conn, hs):
+        super(MediaRepositoryBackgroundUpdateStore, self).__init__(
+            database, db_conn, hs
+        )
 
-        self.register_background_index_update(
+        self.db.updates.register_background_index_update(
             update_name="local_media_repository_url_idx",
             index_name="local_media_repository_url_idx",
             table="local_media_repository",
@@ -29,12 +30,19 @@ class MediaRepositoryStore(BackgroundUpdateStore):
             where_clause="url_cache IS NOT NULL",
         )
 
+
+class MediaRepositoryStore(MediaRepositoryBackgroundUpdateStore):
+    """Persistence for attachments and avatars"""
+
+    def __init__(self, database: Database, db_conn, hs):
+        super(MediaRepositoryStore, self).__init__(database, db_conn, hs)
+
     def get_local_media(self, media_id):
         """Get the metadata for a local piece of media
         Returns:
             None if the media_id doesn't exist.
         """
-        return self._simple_select_one(
+        return self.db.simple_select_one(
             "local_media_repository",
             {"media_id": media_id},
             (
@@ -59,7 +67,7 @@ class MediaRepositoryStore(BackgroundUpdateStore):
         user_id,
         url_cache=None,
     ):
-        return self._simple_insert(
+        return self.db.simple_insert(
             "local_media_repository",
             {
                 "media_id": media_id,
@@ -119,12 +127,12 @@ class MediaRepositoryStore(BackgroundUpdateStore):
                 )
             )
 
-        return self.runInteraction("get_url_cache", get_url_cache_txn)
+        return self.db.runInteraction("get_url_cache", get_url_cache_txn)
 
     def store_url_cache(
         self, url, response_code, etag, expires_ts, og, media_id, download_ts
     ):
-        return self._simple_insert(
+        return self.db.simple_insert(
             "local_media_repository_url_cache",
             {
                 "url": url,
@@ -139,7 +147,7 @@ class MediaRepositoryStore(BackgroundUpdateStore):
         )
 
     def get_local_media_thumbnails(self, media_id):
-        return self._simple_select_list(
+        return self.db.simple_select_list(
             "local_media_repository_thumbnails",
             {"media_id": media_id},
             (
@@ -161,7 +169,7 @@ class MediaRepositoryStore(BackgroundUpdateStore):
         thumbnail_method,
         thumbnail_length,
     ):
-        return self._simple_insert(
+        return self.db.simple_insert(
             "local_media_repository_thumbnails",
             {
                 "media_id": media_id,
@@ -175,7 +183,7 @@ class MediaRepositoryStore(BackgroundUpdateStore):
         )
 
     def get_cached_remote_media(self, origin, media_id):
-        return self._simple_select_one(
+        return self.db.simple_select_one(
             "remote_media_cache",
             {"media_origin": origin, "media_id": media_id},
             (
@@ -200,7 +208,7 @@ class MediaRepositoryStore(BackgroundUpdateStore):
         upload_name,
         filesystem_id,
     ):
-        return self._simple_insert(
+        return self.db.simple_insert(
             "remote_media_cache",
             {
                 "media_origin": origin,
@@ -245,10 +253,12 @@ class MediaRepositoryStore(BackgroundUpdateStore):
 
             txn.executemany(sql, ((time_ms, media_id) for media_id in local_media))
 
-        return self.runInteraction("update_cached_last_access_time", update_cache_txn)
+        return self.db.runInteraction(
+            "update_cached_last_access_time", update_cache_txn
+        )
 
     def get_remote_media_thumbnails(self, origin, media_id):
-        return self._simple_select_list(
+        return self.db.simple_select_list(
             "remote_media_cache_thumbnails",
             {"media_origin": origin, "media_id": media_id},
             (
@@ -273,7 +283,7 @@ class MediaRepositoryStore(BackgroundUpdateStore):
         thumbnail_method,
         thumbnail_length,
     ):
-        return self._simple_insert(
+        return self.db.simple_insert(
             "remote_media_cache_thumbnails",
             {
                 "media_origin": origin,
@@ -295,24 +305,24 @@ class MediaRepositoryStore(BackgroundUpdateStore):
             " WHERE last_access_ts < ?"
         )
 
-        return self._execute(
-            "get_remote_media_before", self.cursor_to_dict, sql, before_ts
+        return self.db.execute(
+            "get_remote_media_before", self.db.cursor_to_dict, sql, before_ts
         )
 
     def delete_remote_media(self, media_origin, media_id):
         def delete_remote_media_txn(txn):
-            self._simple_delete_txn(
+            self.db.simple_delete_txn(
                 txn,
                 "remote_media_cache",
                 keyvalues={"media_origin": media_origin, "media_id": media_id},
             )
-            self._simple_delete_txn(
+            self.db.simple_delete_txn(
                 txn,
                 "remote_media_cache_thumbnails",
                 keyvalues={"media_origin": media_origin, "media_id": media_id},
             )
 
-        return self.runInteraction("delete_remote_media", delete_remote_media_txn)
+        return self.db.runInteraction("delete_remote_media", delete_remote_media_txn)
 
     def get_expired_url_cache(self, now_ts):
         sql = (
@@ -326,18 +336,20 @@ class MediaRepositoryStore(BackgroundUpdateStore):
             txn.execute(sql, (now_ts,))
             return [row[0] for row in txn]
 
-        return self.runInteraction("get_expired_url_cache", _get_expired_url_cache_txn)
+        return self.db.runInteraction(
+            "get_expired_url_cache", _get_expired_url_cache_txn
+        )
 
-    def delete_url_cache(self, media_ids):
+    async def delete_url_cache(self, media_ids):
         if len(media_ids) == 0:
             return
 
-        sql = "DELETE FROM local_media_repository_url_cache" " WHERE media_id = ?"
+        sql = "DELETE FROM local_media_repository_url_cache WHERE media_id = ?"
 
         def _delete_url_cache_txn(txn):
             txn.executemany(sql, [(media_id,) for media_id in media_ids])
 
-        return self.runInteraction("delete_url_cache", _delete_url_cache_txn)
+        return await self.db.runInteraction("delete_url_cache", _delete_url_cache_txn)
 
     def get_url_cache_media_before(self, before_ts):
         sql = (
@@ -351,23 +363,23 @@ class MediaRepositoryStore(BackgroundUpdateStore):
             txn.execute(sql, (before_ts,))
             return [row[0] for row in txn]
 
-        return self.runInteraction(
+        return self.db.runInteraction(
             "get_url_cache_media_before", _get_url_cache_media_before_txn
         )
 
-    def delete_url_cache_media(self, media_ids):
+    async def delete_url_cache_media(self, media_ids):
         if len(media_ids) == 0:
             return
 
         def _delete_url_cache_media_txn(txn):
-            sql = "DELETE FROM local_media_repository" " WHERE media_id = ?"
+            sql = "DELETE FROM local_media_repository WHERE media_id = ?"
 
             txn.executemany(sql, [(media_id,) for media_id in media_ids])
 
-            sql = "DELETE FROM local_media_repository_thumbnails" " WHERE media_id = ?"
+            sql = "DELETE FROM local_media_repository_thumbnails WHERE media_id = ?"
 
             txn.executemany(sql, [(media_id,) for media_id in media_ids])
 
-        return self.runInteraction(
+        return await self.db.runInteraction(
             "delete_url_cache_media", _delete_url_cache_media_txn
         )
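
delete_url_cache and delete_url_cache_media above become async def and await
self.db.runInteraction directly. A hypothetical caller (not part of this diff) showing how
the Deferred-returning reads and the new coroutines combine; inside an async function,
Twisted Deferreds can be awaited directly:

    # Sketch only; the function name, `store` and the orchestration are assumptions.
    async def expire_url_cache(store, now_ms):
        # get_expired_url_cache returns a Deferred of media ids past their expiry.
        media_ids = await store.get_expired_url_cache(now_ms)
        if media_ids:
            # delete_url_cache is now a coroutine, so it is awaited the same way.
            await store.delete_url_cache(media_ids)
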
diff --git a/synapse/storage/data_stores/main/metrics.py b/synapse/storage/data_stores/main/metrics.py
new file mode 100644
index 0000000000..dad5bbc602
--- /dev/null
+++ b/synapse/storage/data_stores/main/metrics.py
@@ -0,0 +1,128 @@
+# -*- coding: utf-8 -*-
+# Copyright 2020 The Matrix.org Foundation C.I.C.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+import typing
+from collections import Counter
+
+from twisted.internet import defer
+
+from synapse.metrics import BucketCollector
+from synapse.metrics.background_process_metrics import run_as_background_process
+from synapse.storage._base import SQLBaseStore
+from synapse.storage.data_stores.main.event_push_actions import (
+    EventPushActionsWorkerStore,
+)
+from synapse.storage.database import Database
+
+
+class ServerMetricsStore(EventPushActionsWorkerStore, SQLBaseStore):
+    """Functions to pull various metrics from the DB, for e.g. phone home
+    stats and prometheus metrics.
+    """
+
+    def __init__(self, database: Database, db_conn, hs):
+        super().__init__(database, db_conn, hs)
+
+        # Collect metrics on the number of forward extremities that exist.
+        # Counter of number of extremities to count
+        self._current_forward_extremities_amount = (
+            Counter()
+        )  # type: typing.Counter[int]
+
+        BucketCollector(
+            "synapse_forward_extremities",
+            lambda: self._current_forward_extremities_amount,
+            buckets=[1, 2, 3, 5, 7, 10, 15, 20, 50, 100, 200, 500, "+Inf"],
+        )
+
+        # Read the extrems every 60 minutes
+        def read_forward_extremities():
+            # run as a background process to make sure that the database transactions
+            # have a logcontext to report to
+            return run_as_background_process(
+                "read_forward_extremities", self._read_forward_extremities
+            )
+
+        hs.get_clock().looping_call(read_forward_extremities, 60 * 60 * 1000)
+
+    async def _read_forward_extremities(self):
+        def fetch(txn):
+            txn.execute(
+                """
+                select count(*) c from event_forward_extremities
+                group by room_id
+                """
+            )
+            return txn.fetchall()
+
+        res = await self.db.runInteraction("read_forward_extremities", fetch)
+        self._current_forward_extremities_amount = Counter([x[0] for x in res])
+
+    @defer.inlineCallbacks
+    def count_daily_messages(self):
+        """
+        Returns an estimate of the number of messages sent in the last day.
+
+        If it has been significantly less or more than one day since the last
+        call to this function, it will return None.
+        """
+
+        def _count_messages(txn):
+            sql = """
+                SELECT COALESCE(COUNT(*), 0) FROM events
+                WHERE type = 'm.room.message'
+                AND stream_ordering > ?
+            """
+            txn.execute(sql, (self.stream_ordering_day_ago,))
+            (count,) = txn.fetchone()
+            return count
+
+        ret = yield self.db.runInteraction("count_messages", _count_messages)
+        return ret
+
+    @defer.inlineCallbacks
+    def count_daily_sent_messages(self):
+        def _count_messages(txn):
+            # This is good enough: if you have silly characters in your own
+            # hostname then that's your own fault.
+            like_clause = "%:" + self.hs.hostname
+
+            sql = """
+                SELECT COALESCE(COUNT(*), 0) FROM events
+                WHERE type = 'm.room.message'
+                    AND sender LIKE ?
+                AND stream_ordering > ?
+            """
+
+            txn.execute(sql, (like_clause, self.stream_ordering_day_ago))
+            (count,) = txn.fetchone()
+            return count
+
+        ret = yield self.db.runInteraction("count_daily_sent_messages", _count_messages)
+        return ret
+
+    @defer.inlineCallbacks
+    def count_daily_active_rooms(self):
+        def _count(txn):
+            sql = """
+                SELECT COALESCE(COUNT(DISTINCT room_id), 0) FROM events
+                WHERE type = 'm.room.message'
+                AND stream_ordering > ?
+            """
+            txn.execute(sql, (self.stream_ordering_day_ago,))
+            (count,) = txn.fetchone()
+            return count
+
+        ret = yield self.db.runInteraction("count_daily_active_rooms", _count)
+        return ret
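
_read_forward_extremities above reduces the per-room counts returned by the fetch
transaction into a collections.Counter, which the BucketCollector registered in __init__
samples against its histogram buckets. A tiny worked example of that reduction (row values
invented):

    from collections import Counter

    # One row per room, each carrying that room's forward-extremity count:
    rows = [(3,), (1,), (3,), (7,)]

    extremities = Counter(x[0] for x in rows)
    # Counter({3: 2, 1: 1, 7: 1}) -- two rooms have 3 forward extremities,
    # one room has 1 and one has 7; these tallies are what the
    # synapse_forward_extremities buckets (1, 2, 3, 5, ...) report.
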
diff --git a/synapse/storage/monthly_active_users.py b/synapse/storage/data_stores/main/monthly_active_users.py
index 752e9788a2..e459cf49a0 100644
--- a/synapse/storage/monthly_active_users.py
+++ b/synapse/storage/data_stores/main/monthly_active_users.py
@@ -13,13 +13,14 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 import logging
+from typing import List
 
 from twisted.internet import defer
 
+from synapse.storage._base import SQLBaseStore
+from synapse.storage.database import Database, make_in_list_sql_clause
 from synapse.util.caches.descriptors import cached
 
-from ._base import SQLBaseStore
-
 logger = logging.getLogger(__name__)
 
 # Number of msec of granularity to store the monthly_active_user timestamp
@@ -27,20 +28,113 @@ logger = logging.getLogger(__name__)
 LAST_SEEN_GRANULARITY = 60 * 60 * 1000
 
 
-class MonthlyActiveUsersStore(SQLBaseStore):
-    def __init__(self, dbconn, hs):
-        super(MonthlyActiveUsersStore, self).__init__(None, hs)
+class MonthlyActiveUsersWorkerStore(SQLBaseStore):
+    def __init__(self, database: Database, db_conn, hs):
+        super(MonthlyActiveUsersWorkerStore, self).__init__(database, db_conn, hs)
         self._clock = hs.get_clock()
         self.hs = hs
-        self.reserved_users = ()
+
+    @cached(num_args=0)
+    def get_monthly_active_count(self):
+        """Generates current count of monthly active users
+
+        Returns:
+            Deferred[int]: Number of current monthly active users
+        """
+
+        def _count_users(txn):
+            sql = "SELECT COALESCE(count(*), 0) FROM monthly_active_users"
+            txn.execute(sql)
+            (count,) = txn.fetchone()
+            return count
+
+        return self.db.runInteraction("count_users", _count_users)
+
+    @cached(num_args=0)
+    def get_monthly_active_count_by_service(self):
+        """Generates current count of monthly active users broken down by service.
+        A service is typically an appservice but also includes native matrix users.
+        Since the `monthly_active_users` table is populated from the `user_ips` table,
+        `config.track_appservice_user_ips` must be set to `true` for this
+        method to return anything other than native matrix users.
+
+        Returns:
+            Deferred[dict]: dict that includes a mapping between app_service_id
+                and the number of occurrences.
+
+        """
+
+        def _count_users_by_service(txn):
+            sql = """
+                SELECT COALESCE(appservice_id, 'native'), COALESCE(count(*), 0)
+                FROM monthly_active_users
+                LEFT JOIN users ON monthly_active_users.user_id=users.name
+                GROUP BY appservice_id;
+            """
+
+            txn.execute(sql)
+            result = txn.fetchall()
+            return dict(result)
+
+        return self.db.runInteraction("count_users_by_service", _count_users_by_service)
+
+    async def get_registered_reserved_users(self) -> List[str]:
+        """Of the reserved threepids defined in config, retrieve those that are associated
+        with registered users
+
+        Returns:
+            User IDs of actual users that are reserved
+        """
+        users = []
+
+        for tp in self.hs.config.mau_limits_reserved_threepids[
+            : self.hs.config.max_mau_value
+        ]:
+            user_id = await self.hs.get_datastore().get_user_id_by_threepid(
+                tp["medium"], tp["address"]
+            )
+            if user_id:
+                users.append(user_id)
+
+        return users
+
+    @cached(num_args=1)
+    def user_last_seen_monthly_active(self, user_id):
+        """
+            Checks if a given user is part of the monthly active user group
+            Arguments:
+                user_id (str): user to check
+            Returns:
+                Deferred[int]: timestamp since last seen, None if never seen
+
+        """
+
+        return self.db.simple_select_one_onecol(
+            table="monthly_active_users",
+            keyvalues={"user_id": user_id},
+            retcol="timestamp",
+            allow_none=True,
+            desc="user_last_seen_monthly_active",
+        )
+
+
+class MonthlyActiveUsersStore(MonthlyActiveUsersWorkerStore):
+    def __init__(self, database: Database, db_conn, hs):
+        super(MonthlyActiveUsersStore, self).__init__(database, db_conn, hs)
+
+        self._limit_usage_by_mau = hs.config.limit_usage_by_mau
+        self._mau_stats_only = hs.config.mau_stats_only
+        self._max_mau_value = hs.config.max_mau_value
+
         # Do not add more reserved users than the total allowable number
-        self._new_transaction(
-            dbconn,
+        # cur = LoggingTransaction(
+        self.db.new_transaction(
+            db_conn,
             "initialise_mau_threepids",
             [],
             [],
             self._initialise_reserved_users,
-            hs.config.mau_limits_reserved_threepids[: self.hs.config.max_mau_value],
+            hs.config.mau_limits_reserved_threepids[: self._max_mau_value],
         )
 
     def _initialise_reserved_users(self, txn, threepids):
@@ -51,7 +145,15 @@ class MonthlyActiveUsersStore(SQLBaseStore):
             txn (cursor):
             threepids (list[dict]): List of threepid dicts to reserve
         """
-        reserved_user_list = []
+
+        # XXX what is this function trying to achieve?  It upserts into
+        # monthly_active_users for each *registered* reserved mau user, but why?
+        #
+        #  - shouldn't there already be an entry for each reserved user (at least
+        #    if they have been active recently)?
+        #
+        #  - if it's important that the timestamp is kept up to date, why do we only
+        #    run this at startup?
 
         for tp in threepids:
             user_id = self.get_user_id_by_threepid_txn(txn, tp["medium"], tp["address"])
@@ -59,121 +161,94 @@ class MonthlyActiveUsersStore(SQLBaseStore):
             if user_id:
                 is_support = self.is_support_user_txn(txn, user_id)
                 if not is_support:
-                    self.upsert_monthly_active_user_txn(txn, user_id)
-                    reserved_user_list.append(user_id)
+                    # We do this manually here to avoid hitting #6791
+                    self.db.simple_upsert_txn(
+                        txn,
+                        table="monthly_active_users",
+                        keyvalues={"user_id": user_id},
+                        values={"timestamp": int(self._clock.time_msec())},
+                    )
             else:
                 logger.warning("mau limit reserved threepid %s not found in db" % tp)
-        self.reserved_users = tuple(reserved_user_list)
 
-    @defer.inlineCallbacks
-    def reap_monthly_active_users(self):
+    async def reap_monthly_active_users(self):
         """Cleans out monthly active user table to ensure that no stale
         entries exist.
-
-        Returns:
-            Deferred[]
         """
 
-        def _reap_users(txn):
-            # Purge stale users
+        def _reap_users(txn, reserved_users):
+            """
+            Args:
+                reserved_users (tuple): reserved users to preserve
+            """
 
             thirty_days_ago = int(self._clock.time_msec()) - (1000 * 60 * 60 * 24 * 30)
-            query_args = [thirty_days_ago]
-            base_sql = "DELETE FROM monthly_active_users WHERE timestamp < ?"
-
-            # Need if/else since 'AND user_id NOT IN ({})' fails on Postgres
-            # when len(reserved_users) == 0. Works fine on sqlite.
-            if len(self.reserved_users) > 0:
-                # questionmarks is a hack to overcome sqlite not supporting
-                # tuples in 'WHERE IN %s'
-                questionmarks = "?" * len(self.reserved_users)
-
-                query_args.extend(self.reserved_users)
-                sql = base_sql + """ AND user_id NOT IN ({})""".format(
-                    ",".join(questionmarks)
-                )
-            else:
-                sql = base_sql
 
-            txn.execute(sql, query_args)
+            in_clause, in_clause_args = make_in_list_sql_clause(
+                self.database_engine, "user_id", reserved_users
+            )
+
+            txn.execute(
+                "DELETE FROM monthly_active_users WHERE timestamp < ? AND NOT %s"
+                % (in_clause,),
+                [thirty_days_ago] + in_clause_args,
+            )
 
-            if self.hs.config.limit_usage_by_mau:
+            if self._limit_usage_by_mau:
                 # If MAU user count still exceeds the MAU threshold, then delete on
                 # a least recently active basis.
                 # Note it is not possible to write this query using OFFSET due to
                 # incompatibilities in how sqlite and postgres support the feature.
-                # sqlite requires 'LIMIT -1 OFFSET ?', the LIMIT must be present
-                # While Postgres does not require 'LIMIT', but also does not support
+                # Sqlite requires 'LIMIT -1 OFFSET ?', the LIMIT must be present,
+                # while Postgres does not require 'LIMIT', but also does not support
                 # negative LIMIT values. So there is no way to write it that both can
                 # support
-                safe_guard = self.hs.config.max_mau_value - len(self.reserved_users)
-                # Must be greater than zero for postgres
-                safe_guard = safe_guard if safe_guard > 0 else 0
-                query_args = [safe_guard]
 
-                base_sql = """
+                # Limit must be >= 0 for postgres
+                num_of_non_reserved_users_to_remove = max(
+                    self._max_mau_value - len(reserved_users), 0
+                )
+
+                # It is important to filter reserved users twice: once in the
+                # inner SELECT and once in the outer DELETE. Otherwise a
+                # reserved user could occupy a slot in the SELECT, and a
+                # legitimate MAU entry would be deleted in its place.
+                sql = """
                     DELETE FROM monthly_active_users
                     WHERE user_id NOT IN (
                         SELECT user_id FROM monthly_active_users
+                        WHERE NOT %s
                         ORDER BY timestamp DESC
                         LIMIT ?
-                        )
-                    """
-                # Need if/else since 'AND user_id NOT IN ({})' fails on Postgres
-                # when len(reserved_users) == 0. Works fine on sqlite.
-                if len(self.reserved_users) > 0:
-                    query_args.extend(self.reserved_users)
-                    sql = base_sql + """ AND user_id NOT IN ({})""".format(
-                        ",".join(questionmarks)
                     )
-                else:
-                    sql = base_sql
-                txn.execute(sql, query_args)
-
-        yield self.runInteraction("reap_monthly_active_users", _reap_users)
-        # It seems poor to invalidate the whole cache, Postgres supports
-        # 'Returning' which would allow me to invalidate only the
-        # specific users, but sqlite has no way to do this and instead
-        # I would need to SELECT and the DELETE which without locking
-        # is racy.
-        # Have resolved to invalidate the whole cache for now and do
-        # something about it if and when the perf becomes significant
-        self.user_last_seen_monthly_active.invalidate_all()
-        self.get_monthly_active_count.invalidate_all()
-
-    @cached(num_args=0)
-    def get_monthly_active_count(self):
-        """Generates current count of monthly active users
-
-        Returns:
-            Defered[int]: Number of current monthly active users
-        """
-
-        def _count_users(txn):
-            sql = "SELECT COALESCE(count(*), 0) FROM monthly_active_users"
-
-            txn.execute(sql)
-            count, = txn.fetchone()
-            return count
-
-        return self.runInteraction("count_users", _count_users)
+                    AND NOT %s
+                """ % (
+                    in_clause,
+                    in_clause,
+                )
 
-    @defer.inlineCallbacks
-    def get_registered_reserved_users_count(self):
-        """Of the reserved threepids defined in config, how many are associated
-        with registered users?
+                query_args = (
+                    in_clause_args
+                    + [num_of_non_reserved_users_to_remove]
+                    + in_clause_args
+                )
+                txn.execute(sql, query_args)
 
-        Returns:
-            Defered[int]: Number of real reserved users
-        """
-        count = 0
-        for tp in self.hs.config.mau_limits_reserved_threepids:
-            user_id = yield self.hs.get_datastore().get_user_id_by_threepid(
-                tp["medium"], tp["address"]
+            # It seems poor to invalidate the whole cache. Postgres supports
+            # RETURNING, which would allow us to invalidate only the specific
+            # users, but sqlite has no way to do this; instead we would need
+            # to SELECT and then DELETE, which is racy without locking.
+            # We have resolved to invalidate the whole cache for now, and to
+            # do something about it if and when the perf becomes significant.
+            self._invalidate_all_cache_and_stream(
+                txn, self.user_last_seen_monthly_active
             )
-            if user_id:
-                count = count + 1
-        return count
+            self._invalidate_cache_and_stream(txn, self.get_monthly_active_count, ())
+
+        reserved_users = await self.get_registered_reserved_users()
+        await self.db.runInteraction(
+            "reap_monthly_active_users", _reap_users, reserved_users
+        )
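
The reaping query above builds its reserved-user filter with make_in_list_sql_clause so that the empty-tuple case and the sqlite/postgres differences are handled in one place. As a rough, standalone sketch of the idea (not Synapse's actual helper, which may instead emit `user_id = ANY(?)` on Postgres):

def make_in_list_clause_sketch(column, values):
    """Build a '<column> IN (?, ?, ...)' clause plus its argument list.

    Simplified, illustrative stand-in for make_in_list_sql_clause: it always
    expands positional placeholders.
    """
    values = list(values)
    if not values:
        # Degenerate case for this sketch: match nothing, so "NOT <clause>"
        # keeps every row.
        return "1 = 0", []
    placeholders = ", ".join("?" for _ in values)
    return "%s IN (%s)" % (column, placeholders), values


# Example: the DELETE used by _reap_users, for two reserved users.
in_clause, args = make_in_list_clause_sketch("user_id", ("@a:hs", "@b:hs"))
sql = "DELETE FROM monthly_active_users WHERE timestamp < ? AND NOT %s" % in_clause
params = [1234567890] + args
# sql    == "DELETE FROM monthly_active_users WHERE timestamp < ? AND NOT user_id IN (?, ?)"
# params == [1234567890, '@a:hs', '@b:hs']
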
 
     @defer.inlineCallbacks
     def upsert_monthly_active_user(self, user_id):
@@ -182,6 +257,9 @@ class MonthlyActiveUsersStore(SQLBaseStore):
 
         Args:
             user_id (str): user to add/update
+
+        Returns:
+            Deferred
         """
         # Support user never to be included in MAU stats. Note I can't easily call this
         # from upsert_monthly_active_user_txn because then I need a _txn form of
@@ -195,27 +273,13 @@ class MonthlyActiveUsersStore(SQLBaseStore):
         if is_support:
             return
 
-        yield self.runInteraction(
+        yield self.db.runInteraction(
             "upsert_monthly_active_user", self.upsert_monthly_active_user_txn, user_id
         )
 
-        user_in_mau = self.user_last_seen_monthly_active.cache.get(
-            (user_id,), None, update_metrics=False
-        )
-        if user_in_mau is None:
-            self.get_monthly_active_count.invalidate(())
-
-        self.user_last_seen_monthly_active.invalidate((user_id,))
-
     def upsert_monthly_active_user_txn(self, txn, user_id):
         """Updates or inserts monthly active user member
 
-        Note that, after calling this method, it will generally be necessary
-        to invalidate the caches on user_last_seen_monthly_active and
-        get_monthly_active_count. We can't do that here, because we are running
-        in a database thread rather than the main thread, and we can't call
-        txn.call_after because txn may not be a LoggingTransaction.
-
         We consciously do not call is_support_txn from this method because it
         is not possible to cache the response. is_support_txn will be false in
         almost all cases, so it seems reasonable to call it only for
@@ -239,33 +303,22 @@ class MonthlyActiveUsersStore(SQLBaseStore):
         # never be a big table and alternative approaches (batching multiple
         # upserts into a single txn) introduced a lot of extra complexity.
         # See https://github.com/matrix-org/synapse/issues/3854 for more
-        is_insert = self._simple_upsert_txn(
+        is_insert = self.db.simple_upsert_txn(
             txn,
             table="monthly_active_users",
             keyvalues={"user_id": user_id},
             values={"timestamp": int(self._clock.time_msec())},
         )
 
-        return is_insert
-
-    @cached(num_args=1)
-    def user_last_seen_monthly_active(self, user_id):
-        """
-            Checks if a given user is part of the monthly active user group
-            Arguments:
-                user_id (str): user to add/update
-            Return:
-                Deferred[int] : timestamp since last seen, None if never seen
-
-        """
-
-        return self._simple_select_one_onecol(
-            table="monthly_active_users",
-            keyvalues={"user_id": user_id},
-            retcol="timestamp",
-            allow_none=True,
-            desc="user_last_seen_monthly_active",
+        self._invalidate_cache_and_stream(txn, self.get_monthly_active_count, ())
+        self._invalidate_cache_and_stream(
+            txn, self.get_monthly_active_count_by_service, ()
         )
+        self._invalidate_cache_and_stream(
+            txn, self.user_last_seen_monthly_active, (user_id,)
+        )
+
+        return is_insert
 
     @defer.inlineCallbacks
     def populate_monthly_active_users(self, user_id):
@@ -275,7 +328,7 @@ class MonthlyActiveUsersStore(SQLBaseStore):
         Args:
             user_id(str): the user_id to query
         """
-        if self.hs.config.limit_usage_by_mau or self.hs.config.mau_stats_only:
+        if self._limit_usage_by_mau or self._mau_stats_only:
             # Trial users and guests should not be included as part of MAU group
             is_guest = yield self.is_guest(user_id)
             if is_guest:
@@ -296,11 +349,11 @@ class MonthlyActiveUsersStore(SQLBaseStore):
                 # In the case where mau_stats_only is True and limit_usage_by_mau is
                 # False, there is no point in checking get_monthly_active_count - it
                 # adds no value and will break the logic if max_mau_value is exceeded.
-                if not self.hs.config.limit_usage_by_mau:
+                if not self._limit_usage_by_mau:
                     yield self.upsert_monthly_active_user(user_id)
                 else:
                     count = yield self.get_monthly_active_count()
-                    if count < self.hs.config.max_mau_value:
+                    if count < self._max_mau_value:
                         yield self.upsert_monthly_active_user(user_id)
             elif now - last_seen_timestamp > LAST_SEEN_GRANULARITY:
                 yield self.upsert_monthly_active_user(user_id)
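
The last branch above only refreshes the stored timestamp once it is stale by more than LAST_SEEN_GRANULARITY, which stops busy users from rewriting the row on every request. A small standalone illustration of that decision; the one-hour granularity is an assumed value, not taken from this diff:

LAST_SEEN_GRANULARITY = 60 * 60 * 1000  # assumed: one hour, in milliseconds


def should_refresh_mau_timestamp(last_seen_timestamp, now_ms):
    """Return True if the monthly_active_users row should be (re-)upserted.

    Mirrors the shape of populate_monthly_active_users: a missing row always
    needs an upsert; an existing row is only touched once it has gone stale.
    """
    if last_seen_timestamp is None:
        return True
    return now_ms - last_seen_timestamp > LAST_SEEN_GRANULARITY


now = 10000000000
assert should_refresh_mau_timestamp(None, now)
assert not should_refresh_mau_timestamp(now - 1000, now)
assert should_refresh_mau_timestamp(now - 2 * LAST_SEEN_GRANULARITY, now)
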
diff --git a/synapse/storage/openid.py b/synapse/storage/data_stores/main/openid.py
index b3318045ee..cc21437e92 100644
--- a/synapse/storage/openid.py
+++ b/synapse/storage/data_stores/main/openid.py
@@ -1,9 +1,9 @@
-from ._base import SQLBaseStore
+from synapse.storage._base import SQLBaseStore
 
 
 class OpenIdStore(SQLBaseStore):
     def insert_open_id_token(self, token, ts_valid_until_ms, user_id):
-        return self._simple_insert(
+        return self.db.simple_insert(
             table="open_id_tokens",
             values={
                 "token": token,
@@ -28,4 +28,6 @@ class OpenIdStore(SQLBaseStore):
             else:
                 return rows[0][0]
 
-        return self.runInteraction("get_user_id_for_token", get_user_id_for_token_txn)
+        return self.db.runInteraction(
+            "get_user_id_for_token", get_user_id_for_token_txn
+        )
diff --git a/synapse/storage/data_stores/main/presence.py b/synapse/storage/data_stores/main/presence.py
new file mode 100644
index 0000000000..dab31e0c2d
--- /dev/null
+++ b/synapse/storage/data_stores/main/presence.py
@@ -0,0 +1,153 @@
+# -*- coding: utf-8 -*-
+# Copyright 2014-2016 OpenMarket Ltd
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from twisted.internet import defer
+
+from synapse.storage._base import SQLBaseStore, make_in_list_sql_clause
+from synapse.storage.presence import UserPresenceState
+from synapse.util.caches.descriptors import cached, cachedList
+from synapse.util.iterutils import batch_iter
+
+
+class PresenceStore(SQLBaseStore):
+    @defer.inlineCallbacks
+    def update_presence(self, presence_states):
+        stream_ordering_manager = self._presence_id_gen.get_next_mult(
+            len(presence_states)
+        )
+
+        with stream_ordering_manager as stream_orderings:
+            yield self.db.runInteraction(
+                "update_presence",
+                self._update_presence_txn,
+                stream_orderings,
+                presence_states,
+            )
+
+        return stream_orderings[-1], self._presence_id_gen.get_current_token()
+
+    def _update_presence_txn(self, txn, stream_orderings, presence_states):
+        for stream_id, state in zip(stream_orderings, presence_states):
+            txn.call_after(
+                self.presence_stream_cache.entity_has_changed, state.user_id, stream_id
+            )
+            txn.call_after(self._get_presence_for_user.invalidate, (state.user_id,))
+
+        # Actually insert new rows
+        self.db.simple_insert_many_txn(
+            txn,
+            table="presence_stream",
+            values=[
+                {
+                    "stream_id": stream_id,
+                    "user_id": state.user_id,
+                    "state": state.state,
+                    "last_active_ts": state.last_active_ts,
+                    "last_federation_update_ts": state.last_federation_update_ts,
+                    "last_user_sync_ts": state.last_user_sync_ts,
+                    "status_msg": state.status_msg,
+                    "currently_active": state.currently_active,
+                }
+                for stream_id, state in zip(stream_orderings, presence_states)
+            ],
+        )
+
+        # Delete old rows to stop database from getting really big
+        sql = "DELETE FROM presence_stream WHERE stream_id < ? AND "
+
+        for states in batch_iter(presence_states, 50):
+            clause, args = make_in_list_sql_clause(
+                self.database_engine, "user_id", [s.user_id for s in states]
+            )
+            txn.execute(sql + clause, [stream_id] + list(args))
+
+    def get_all_presence_updates(self, last_id, current_id, limit):
+        if last_id == current_id:
+            return defer.succeed([])
+
+        def get_all_presence_updates_txn(txn):
+            sql = """
+                SELECT stream_id, user_id, state, last_active_ts,
+                    last_federation_update_ts, last_user_sync_ts,
+                    status_msg, currently_active
+                FROM presence_stream
+                WHERE ? < stream_id AND stream_id <= ?
+                ORDER BY stream_id ASC
+                LIMIT ?
+            """
+            txn.execute(sql, (last_id, current_id, limit))
+            return txn.fetchall()
+
+        return self.db.runInteraction(
+            "get_all_presence_updates", get_all_presence_updates_txn
+        )
+
+    @cached()
+    def _get_presence_for_user(self, user_id):
+        raise NotImplementedError()
+
+    @cachedList(
+        cached_method_name="_get_presence_for_user",
+        list_name="user_ids",
+        num_args=1,
+        inlineCallbacks=True,
+    )
+    def get_presence_for_users(self, user_ids):
+        rows = yield self.db.simple_select_many_batch(
+            table="presence_stream",
+            column="user_id",
+            iterable=user_ids,
+            keyvalues={},
+            retcols=(
+                "user_id",
+                "state",
+                "last_active_ts",
+                "last_federation_update_ts",
+                "last_user_sync_ts",
+                "status_msg",
+                "currently_active",
+            ),
+            desc="get_presence_for_users",
+        )
+
+        for row in rows:
+            row["currently_active"] = bool(row["currently_active"])
+
+        return {row["user_id"]: UserPresenceState(**row) for row in rows}
+
+    def get_current_presence_token(self):
+        return self._presence_id_gen.get_current_token()
+
+    def allow_presence_visible(self, observed_localpart, observer_userid):
+        return self.db.simple_insert(
+            table="presence_allow_inbound",
+            values={
+                "observed_user_id": observed_localpart,
+                "observer_user_id": observer_userid,
+            },
+            desc="allow_presence_visible",
+            or_ignore=True,
+        )
+
+    def disallow_presence_visible(self, observed_localpart, observer_userid):
+        return self.db.simple_delete_one(
+            table="presence_allow_inbound",
+            keyvalues={
+                "observed_user_id": observed_localpart,
+                "observer_user_id": observer_userid,
+            },
+            desc="disallow_presence_visible",
+        )
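
_update_presence_txn deletes superseded rows in chunks via batch_iter so the IN list stays bounded. A minimal sketch of that chunking pattern (a stand-in for synapse.util.iterutils.batch_iter, which is believed to yield fixed-size tuples):

from itertools import islice


def batch_iter_sketch(iterable, size):
    """Yield successive tuples of at most `size` items from `iterable`."""
    it = iter(iterable)
    while True:
        chunk = tuple(islice(it, size))
        if not chunk:
            return
        yield chunk


# Example: delete presence rows for many users, 50 user_ids per statement.
user_ids = ["@user%d:hs" % i for i in range(120)]
for chunk in batch_iter_sketch(user_ids, 50):
    placeholders = ", ".join("?" for _ in chunk)
    sql = (
        "DELETE FROM presence_stream WHERE stream_id < ? AND user_id IN (%s)"
        % placeholders
    )
    # txn.execute(sql, [stream_id] + list(chunk))  # one statement per chunk of <= 50
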
diff --git a/synapse/storage/profile.py b/synapse/storage/data_stores/main/profile.py
index 912c1df6be..bfc9369f0b 100644
--- a/synapse/storage/profile.py
+++ b/synapse/storage/data_stores/main/profile.py
@@ -16,16 +16,15 @@
 from twisted.internet import defer
 
 from synapse.api.errors import StoreError
-from synapse.storage.roommember import ProfileInfo
-
-from ._base import SQLBaseStore
+from synapse.storage._base import SQLBaseStore
+from synapse.storage.data_stores.main.roommember import ProfileInfo
 
 
 class ProfileWorkerStore(SQLBaseStore):
     @defer.inlineCallbacks
     def get_profileinfo(self, user_localpart):
         try:
-            profile = yield self._simple_select_one(
+            profile = yield self.db.simple_select_one(
                 table="profiles",
                 keyvalues={"user_id": user_localpart},
                 retcols=("displayname", "avatar_url"),
@@ -43,7 +42,7 @@ class ProfileWorkerStore(SQLBaseStore):
         )
 
     def get_profile_displayname(self, user_localpart):
-        return self._simple_select_one_onecol(
+        return self.db.simple_select_one_onecol(
             table="profiles",
             keyvalues={"user_id": user_localpart},
             retcol="displayname",
@@ -51,7 +50,7 @@ class ProfileWorkerStore(SQLBaseStore):
         )
 
     def get_profile_avatar_url(self, user_localpart):
-        return self._simple_select_one_onecol(
+        return self.db.simple_select_one_onecol(
             table="profiles",
             keyvalues={"user_id": user_localpart},
             retcol="avatar_url",
@@ -59,7 +58,7 @@ class ProfileWorkerStore(SQLBaseStore):
         )
 
     def get_from_remote_profile_cache(self, user_id):
-        return self._simple_select_one(
+        return self.db.simple_select_one(
             table="remote_profile_cache",
             keyvalues={"user_id": user_id},
             retcols=("displayname", "avatar_url"),
@@ -68,12 +67,12 @@ class ProfileWorkerStore(SQLBaseStore):
         )
 
     def create_profile(self, user_localpart):
-        return self._simple_insert(
+        return self.db.simple_insert(
             table="profiles", values={"user_id": user_localpart}, desc="create_profile"
         )
 
     def set_profile_displayname(self, user_localpart, new_displayname):
-        return self._simple_update_one(
+        return self.db.simple_update_one(
             table="profiles",
             keyvalues={"user_id": user_localpart},
             updatevalues={"displayname": new_displayname},
@@ -81,7 +80,7 @@ class ProfileWorkerStore(SQLBaseStore):
         )
 
     def set_profile_avatar_url(self, user_localpart, new_avatar_url):
-        return self._simple_update_one(
+        return self.db.simple_update_one(
             table="profiles",
             keyvalues={"user_id": user_localpart},
             updatevalues={"avatar_url": new_avatar_url},
@@ -96,7 +95,7 @@ class ProfileStore(ProfileWorkerStore):
         This should only be called when `is_subscribed_remote_profile_for_user`
         would return true for the user.
         """
-        return self._simple_upsert(
+        return self.db.simple_upsert(
             table="remote_profile_cache",
             keyvalues={"user_id": user_id},
             values={
@@ -108,10 +107,10 @@ class ProfileStore(ProfileWorkerStore):
         )
 
     def update_remote_profile_cache(self, user_id, displayname, avatar_url):
-        return self._simple_update(
+        return self.db.simple_update(
             table="remote_profile_cache",
             keyvalues={"user_id": user_id},
-            values={
+            updatevalues={
                 "displayname": displayname,
                 "avatar_url": avatar_url,
                 "last_check": self._clock.time_msec(),
@@ -126,7 +125,7 @@ class ProfileStore(ProfileWorkerStore):
         """
         subscribed = yield self.is_subscribed_remote_profile_for_user(user_id)
         if not subscribed:
-            yield self._simple_delete(
+            yield self.db.simple_delete(
                 table="remote_profile_cache",
                 keyvalues={"user_id": user_id},
                 desc="delete_remote_profile_cache",
@@ -145,9 +144,9 @@ class ProfileStore(ProfileWorkerStore):
 
             txn.execute(sql, (last_checked,))
 
-            return self.cursor_to_dict(txn)
+            return self.db.cursor_to_dict(txn)
 
-        return self.runInteraction(
+        return self.db.runInteraction(
             "get_remote_profile_cache_entries_that_expire",
             _get_remote_profile_cache_entries_that_expire_txn,
         )
@@ -156,7 +155,7 @@ class ProfileStore(ProfileWorkerStore):
     def is_subscribed_remote_profile_for_user(self, user_id):
         """Check whether we are interested in a remote user's profile.
         """
-        res = yield self._simple_select_one_onecol(
+        res = yield self.db.simple_select_one_onecol(
             table="group_users",
             keyvalues={"user_id": user_id},
             retcol="user_id",
@@ -167,7 +166,7 @@ class ProfileStore(ProfileWorkerStore):
         if res:
             return True
 
-        res = yield self._simple_select_one_onecol(
+        res = yield self.db.simple_select_one_onecol(
             table="group_invites",
             keyvalues={"user_id": user_id},
             retcol="user_id",
diff --git a/synapse/storage/data_stores/main/purge_events.py b/synapse/storage/data_stores/main/purge_events.py
new file mode 100644
index 0000000000..a93e1ef198
--- /dev/null
+++ b/synapse/storage/data_stores/main/purge_events.py
@@ -0,0 +1,399 @@
+# -*- coding: utf-8 -*-
+# Copyright 2020 The Matrix.org Foundation C.I.C.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import logging
+from typing import Any, Tuple
+
+from synapse.api.errors import SynapseError
+from synapse.storage._base import SQLBaseStore
+from synapse.storage.data_stores.main.state import StateGroupWorkerStore
+from synapse.types import RoomStreamToken
+
+logger = logging.getLogger(__name__)
+
+
+class PurgeEventsStore(StateGroupWorkerStore, SQLBaseStore):
+    def purge_history(self, room_id, token, delete_local_events):
+        """Deletes room history before a certain point
+
+        Args:
+            room_id (str):
+
+            token (str): A topological token to delete events before
+
+            delete_local_events (bool):
+                if True, we will delete local events as well as remote ones
+                (instead of just marking them as outliers and deleting their
+                state groups).
+
+        Returns:
+            Deferred[set[int]]: The set of state groups that are referenced by
+            deleted events.
+        """
+
+        return self.db.runInteraction(
+            "purge_history",
+            self._purge_history_txn,
+            room_id,
+            token,
+            delete_local_events,
+        )
+
+    def _purge_history_txn(self, txn, room_id, token_str, delete_local_events):
+        token = RoomStreamToken.parse(token_str)
+
+        # Tables that should be pruned:
+        #     event_auth
+        #     event_backward_extremities
+        #     event_edges
+        #     event_forward_extremities
+        #     event_json
+        #     event_push_actions
+        #     event_reference_hashes
+        #     event_search
+        #     event_to_state_groups
+        #     events
+        #     rejections
+        #     room_depth
+        #     state_groups
+        #     state_groups_state
+
+        # we will build a temporary table listing the events so that we don't
+        # have to keep shovelling the list back and forth across the
+        # connection. Annoyingly the python sqlite driver commits the
+        # transaction on CREATE, so let's do this first.
+        #
+        # furthermore, we might already have the table from a previous (failed)
+        # purge attempt, so let's drop the table first.
+
+        txn.execute("DROP TABLE IF EXISTS events_to_purge")
+
+        txn.execute(
+            "CREATE TEMPORARY TABLE events_to_purge ("
+            "    event_id TEXT NOT NULL,"
+            "    should_delete BOOLEAN NOT NULL"
+            ")"
+        )
+
+        # First ensure that we're not about to delete all the forward extremities
+        txn.execute(
+            "SELECT e.event_id, e.depth FROM events as e "
+            "INNER JOIN event_forward_extremities as f "
+            "ON e.event_id = f.event_id "
+            "AND e.room_id = f.room_id "
+            "WHERE f.room_id = ?",
+            (room_id,),
+        )
+        rows = txn.fetchall()
+        max_depth = max(row[1] for row in rows)
+
+        if max_depth < token.topological:
+            # We need to ensure we don't delete all the events from the database
+            # otherwise we wouldn't be able to send any events (due to not
+            # having any backwards extremities)
+            raise SynapseError(
+                400, "topological_ordering is greater than forward extremeties"
+            )
+
+        logger.info("[purge] looking for events to delete")
+
+        should_delete_expr = "state_key IS NULL"
+        should_delete_params = ()  # type: Tuple[Any, ...]
+        if not delete_local_events:
+            should_delete_expr += " AND event_id NOT LIKE ?"
+
+            # We include the parameter twice since we use the expression twice
+            should_delete_params += ("%:" + self.hs.hostname, "%:" + self.hs.hostname)
+
+        should_delete_params += (room_id, token.topological)
+
+        # Note that we insert events that are outliers and aren't going to be
+        # deleted, as nothing will happen to them.
+        txn.execute(
+            "INSERT INTO events_to_purge"
+            " SELECT event_id, %s"
+            " FROM events AS e LEFT JOIN state_events USING (event_id)"
+            " WHERE (NOT outlier OR (%s)) AND e.room_id = ? AND topological_ordering < ?"
+            % (should_delete_expr, should_delete_expr),
+            should_delete_params,
+        )
+
+        # We create the indices *after* insertion as that's a lot faster.
+
+        # create an index on should_delete because later we'll be looking for
+        # the should_delete / shouldn't_delete subsets
+        txn.execute(
+            "CREATE INDEX events_to_purge_should_delete"
+            " ON events_to_purge(should_delete)"
+        )
+
+        # We do joins against events_to_purge for e.g. calculating state
+        # groups to purge, etc., so let's make an index.
+        txn.execute("CREATE INDEX events_to_purge_id ON events_to_purge(event_id)")
+
+        txn.execute("SELECT event_id, should_delete FROM events_to_purge")
+        event_rows = txn.fetchall()
+        logger.info(
+            "[purge] found %i events before cutoff, of which %i can be deleted",
+            len(event_rows),
+            sum(1 for e in event_rows if e[1]),
+        )
+
+        logger.info("[purge] Finding new backward extremities")
+
+        # We calculate the new entries for the backward extremities by finding
+        # events to be purged that are pointed to by events we're not going to
+        # purge.
+        txn.execute(
+            "SELECT DISTINCT e.event_id FROM events_to_purge AS e"
+            " INNER JOIN event_edges AS ed ON e.event_id = ed.prev_event_id"
+            " LEFT JOIN events_to_purge AS ep2 ON ed.event_id = ep2.event_id"
+            " WHERE ep2.event_id IS NULL"
+        )
+        new_backwards_extrems = txn.fetchall()
+
+        logger.info("[purge] replacing backward extremities: %r", new_backwards_extrems)
+
+        txn.execute(
+            "DELETE FROM event_backward_extremities WHERE room_id = ?", (room_id,)
+        )
+
+        # Update backward extremities
+        txn.executemany(
+            "INSERT INTO event_backward_extremities (room_id, event_id)"
+            " VALUES (?, ?)",
+            [(room_id, event_id) for event_id, in new_backwards_extrems],
+        )
+
+        logger.info("[purge] finding state groups referenced by deleted events")
+
+        # Get all state groups that are referenced by events that are to be
+        # deleted.
+        txn.execute(
+            """
+            SELECT DISTINCT state_group FROM events_to_purge
+            INNER JOIN event_to_state_groups USING (event_id)
+        """
+        )
+
+        referenced_state_groups = {sg for sg, in txn}
+        logger.info(
+            "[purge] found %i referenced state groups", len(referenced_state_groups)
+        )
+
+        logger.info("[purge] removing events from event_to_state_groups")
+        txn.execute(
+            "DELETE FROM event_to_state_groups "
+            "WHERE event_id IN (SELECT event_id from events_to_purge)"
+        )
+        for event_id, _ in event_rows:
+            txn.call_after(self._get_state_group_for_event.invalidate, (event_id,))
+
+        # Delete all remote non-state events
+        for table in (
+            "events",
+            "event_json",
+            "event_auth",
+            "event_edges",
+            "event_forward_extremities",
+            "event_reference_hashes",
+            "event_search",
+            "rejections",
+        ):
+            logger.info("[purge] removing events from %s", table)
+
+            txn.execute(
+                "DELETE FROM %s WHERE event_id IN ("
+                "    SELECT event_id FROM events_to_purge WHERE should_delete"
+                ")" % (table,)
+            )
+
+        # event_push_actions lacks an index on event_id, and has one on
+        # (room_id, event_id) instead.
+        for table in ("event_push_actions",):
+            logger.info("[purge] removing events from %s", table)
+
+            txn.execute(
+                "DELETE FROM %s WHERE room_id = ? AND event_id IN ("
+                "    SELECT event_id FROM events_to_purge WHERE should_delete"
+                ")" % (table,),
+                (room_id,),
+            )
+
+        # Mark all state and own events as outliers
+        logger.info("[purge] marking remaining events as outliers")
+        txn.execute(
+            "UPDATE events SET outlier = ?"
+            " WHERE event_id IN ("
+            "    SELECT event_id FROM events_to_purge "
+            "    WHERE NOT should_delete"
+            ")",
+            (True,),
+        )
+
+        # synapse tries to take out an exclusive lock on room_depth whenever it
+        # persists events (because of the upsert), and once we run this update, we
+        # will block that for the rest of our transaction.
+        #
+        # So, let's stick it at the end so that we don't block event
+        # persistence.
+        #
+        # We do this by calculating the minimum depth of the backwards
+        # extremities. However, the events in event_backward_extremities
+        # are ones we don't have yet so we need to look at the events that
+        # point to it via event_edges table.
+        txn.execute(
+            """
+            SELECT COALESCE(MIN(depth), 0)
+            FROM event_backward_extremities AS eb
+            INNER JOIN event_edges AS eg ON eg.prev_event_id = eb.event_id
+            INNER JOIN events AS e ON e.event_id = eg.event_id
+            WHERE eb.room_id = ?
+        """,
+            (room_id,),
+        )
+        (min_depth,) = txn.fetchone()
+
+        logger.info("[purge] updating room_depth to %d", min_depth)
+
+        txn.execute(
+            "UPDATE room_depth SET min_depth = ? WHERE room_id = ?",
+            (min_depth, room_id),
+        )
+
+        # finally, drop the temp table. this will commit the txn in sqlite,
+        # so make sure to keep this actually last.
+        txn.execute("DROP TABLE events_to_purge")
+
+        logger.info("[purge] done")
+
+        return referenced_state_groups
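
The purge stages its event list in a temporary events_to_purge table so the ids never have to round-trip to the application. A compressed, self-contained sqlite3 sketch of that staging pattern, using a toy schema rather than Synapse's real tables:

import sqlite3

conn = sqlite3.connect(":memory:")
cur = conn.cursor()
cur.execute("CREATE TABLE events (event_id TEXT, topological_ordering INTEGER)")
cur.executemany(
    "INSERT INTO events VALUES (?, ?)",
    [("$a", 1), ("$b", 2), ("$c", 3), ("$d", 4)],
)

# Stage the events to purge in a temporary table, then index and delete.
cur.execute("DROP TABLE IF EXISTS events_to_purge")
cur.execute("CREATE TEMPORARY TABLE events_to_purge (event_id TEXT NOT NULL)")
cur.execute(
    "INSERT INTO events_to_purge"
    " SELECT event_id FROM events WHERE topological_ordering < ?",
    (3,),
)
cur.execute("CREATE INDEX events_to_purge_id ON events_to_purge(event_id)")
cur.execute(
    "DELETE FROM events WHERE event_id IN (SELECT event_id FROM events_to_purge)"
)
cur.execute("DROP TABLE events_to_purge")

print([row[0] for row in cur.execute("SELECT event_id FROM events")])  # ['$c', '$d']
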
+
+    def purge_room(self, room_id):
+        """Deletes all record of a room
+
+        Args:
+            room_id (str)
+
+        Returns:
+            Deferred[List[int]]: The list of state groups to delete.
+        """
+
+        return self.db.runInteraction("purge_room", self._purge_room_txn, room_id)
+
+    def _purge_room_txn(self, txn, room_id):
+        # First we fetch all the state groups that should be deleted, before
+        # we delete that information.
+        txn.execute(
+            """
+                SELECT DISTINCT state_group FROM events
+                INNER JOIN event_to_state_groups USING(event_id)
+                WHERE events.room_id = ?
+            """,
+            (room_id,),
+        )
+
+        state_groups = [row[0] for row in txn]
+
+        # Now we delete tables which lack an index on room_id but have one on event_id
+        for table in (
+            "event_auth",
+            "event_edges",
+            "event_push_actions_staging",
+            "event_reference_hashes",
+            "event_relations",
+            "event_to_state_groups",
+            "redactions",
+            "rejections",
+            "state_events",
+        ):
+            logger.info("[purge] removing %s from %s", room_id, table)
+
+            txn.execute(
+                """
+                DELETE FROM %s WHERE event_id IN (
+                  SELECT event_id FROM events WHERE room_id=?
+                )
+                """
+                % (table,),
+                (room_id,),
+            )
+
+        # and finally, the tables with an index on room_id (or no useful index)
+        for table in (
+            "current_state_events",
+            "event_backward_extremities",
+            "event_forward_extremities",
+            "event_json",
+            "event_push_actions",
+            "event_search",
+            "events",
+            "group_rooms",
+            "public_room_list_stream",
+            "receipts_graph",
+            "receipts_linearized",
+            "room_aliases",
+            "room_depth",
+            "room_memberships",
+            "room_stats_state",
+            "room_stats_current",
+            "room_stats_historical",
+            "room_stats_earliest_token",
+            "rooms",
+            "stream_ordering_to_exterm",
+            "users_in_public_rooms",
+            "users_who_share_private_rooms",
+            # no useful index, but let's clear them anyway
+            "appservice_room_list",
+            "e2e_room_keys",
+            "event_push_summary",
+            "pusher_throttle",
+            "group_summary_rooms",
+            "local_invites",
+            "room_account_data",
+            "room_tags",
+            "local_current_membership",
+        ):
+            logger.info("[purge] removing %s from %s", room_id, table)
+            txn.execute("DELETE FROM %s WHERE room_id=?" % (table,), (room_id,))
+
+        # Other tables we do NOT need to clear out:
+        #
+        #  - blocked_rooms
+        #    This is important, to make sure that we don't accidentally rejoin a blocked
+        #    room after it was purged
+        #
+        #  - user_directory
+        #    This has a room_id column, but it is unused
+        #
+
+        # Other tables that we might want to consider clearing out include:
+        #
+        #  - event_reports
+        #       Given that these are intended for abuse management, my initial
+        #       inclination is to leave them in place.
+        #
+        #  - current_state_delta_stream
+        #  - ex_outlier_stream
+        #  - room_tags_revisions
+        #       The problem with these is that they are largeish and there is no room_id
+        #       index on them. In any case we should be clearing out 'stream' tables
+        #       periodically anyway (#5888)
+
+        # TODO: we could probably usefully do a bunch of cache invalidation here
+
+        logger.info("[purge] done")
+
+        return state_groups
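
Because SQL parameter placeholders can only bind values, not identifiers, the per-table deletes above interpolate the table name from a hard-coded tuple and bind only the room_id. A small sketch of that pattern (toy tables, not Synapse's schema):

import sqlite3

conn = sqlite3.connect(":memory:")
cur = conn.cursor()
for table in ("room_aliases", "room_depth"):
    cur.execute("CREATE TABLE %s (room_id TEXT, other TEXT)" % (table,))
    cur.execute("INSERT INTO %s VALUES (?, ?)" % (table,), ("!r:hs", "x"))

ROOM_ID_TABLES = ("room_aliases", "room_depth")  # fixed allow-list, never user input


def purge_room_sketch(cur, room_id):
    # Table names cannot be bound with '?', so they come from the tuple above;
    # the room_id, by contrast, is always passed as a bound parameter.
    for table in ROOM_ID_TABLES:
        cur.execute("DELETE FROM %s WHERE room_id = ?" % (table,), (room_id,))


purge_room_sketch(cur, "!r:hs")
print(cur.execute("SELECT count(*) FROM room_aliases").fetchone())  # (0,)
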
diff --git a/synapse/storage/data_stores/main/push_rule.py b/synapse/storage/data_stores/main/push_rule.py
new file mode 100644
index 0000000000..ef8f40959f
--- /dev/null
+++ b/synapse/storage/data_stores/main/push_rule.py
@@ -0,0 +1,729 @@
+# -*- coding: utf-8 -*-
+# Copyright 2014-2016 OpenMarket Ltd
+# Copyright 2018 New Vector Ltd
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import abc
+import logging
+from typing import Union
+
+from canonicaljson import json
+
+from twisted.internet import defer
+
+from synapse.push.baserules import list_with_base_rules
+from synapse.replication.slave.storage._slaved_id_tracker import SlavedIdTracker
+from synapse.storage._base import SQLBaseStore
+from synapse.storage.data_stores.main.appservice import ApplicationServiceWorkerStore
+from synapse.storage.data_stores.main.events_worker import EventsWorkerStore
+from synapse.storage.data_stores.main.pusher import PusherWorkerStore
+from synapse.storage.data_stores.main.receipts import ReceiptsWorkerStore
+from synapse.storage.data_stores.main.roommember import RoomMemberWorkerStore
+from synapse.storage.database import Database
+from synapse.storage.push_rule import InconsistentRuleException, RuleNotFoundException
+from synapse.storage.util.id_generators import ChainedIdGenerator
+from synapse.util.caches.descriptors import cachedInlineCallbacks, cachedList
+from synapse.util.caches.stream_change_cache import StreamChangeCache
+
+logger = logging.getLogger(__name__)
+
+
+def _load_rules(rawrules, enabled_map):
+    ruleslist = []
+    for rawrule in rawrules:
+        rule = dict(rawrule)
+        rule["conditions"] = json.loads(rawrule["conditions"])
+        rule["actions"] = json.loads(rawrule["actions"])
+        rule["default"] = False
+        ruleslist.append(rule)
+
+    # We're going to be mutating this a lot, so do a deep copy
+    rules = list(list_with_base_rules(ruleslist))
+
+    for i, rule in enumerate(rules):
+        rule_id = rule["rule_id"]
+        if rule_id in enabled_map:
+            if rule.get("enabled", True) != bool(enabled_map[rule_id]):
+                # Rules are cached across users.
+                rule = dict(rule)
+                rule["enabled"] = bool(enabled_map[rule_id])
+                rules[i] = rule
+
+    return rules
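
_load_rules merges each user's enabled flags into the shared, cached rule list, copying a rule before flipping its flag so the cached dicts are never mutated. A stripped-down sketch of just that merge step:

def apply_enabled_overrides(rules, enabled_map):
    """Return `rules` with per-user enable/disable flags applied.

    Rules are shared across users via a cache, so a rule is copied before its
    "enabled" flag is changed, mirroring the loop in _load_rules.
    """
    rules = list(rules)
    for i, rule in enumerate(rules):
        rule_id = rule["rule_id"]
        if rule_id in enabled_map:
            if rule.get("enabled", True) != bool(enabled_map[rule_id]):
                rule = dict(rule)  # copy before mutating the shared dict
                rule["enabled"] = bool(enabled_map[rule_id])
                rules[i] = rule
    return rules


base = [{"rule_id": ".m.rule.master", "default": True}]
out = apply_enabled_overrides(base, {".m.rule.master": 0})
assert out[0]["enabled"] is False
assert "enabled" not in base[0]  # the shared rule dict was left untouched
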
+
+
+class PushRulesWorkerStore(
+    ApplicationServiceWorkerStore,
+    ReceiptsWorkerStore,
+    PusherWorkerStore,
+    RoomMemberWorkerStore,
+    EventsWorkerStore,
+    SQLBaseStore,
+):
+    """This is an abstract base class where subclasses must implement
+    `get_max_push_rules_stream_id` which can be called in the initializer.
+    """
+
+    # This ABCMeta metaclass ensures that we cannot be instantiated without
+    # the abstract methods being implemented.
+    __metaclass__ = abc.ABCMeta
+
+    def __init__(self, database: Database, db_conn, hs):
+        super(PushRulesWorkerStore, self).__init__(database, db_conn, hs)
+
+        if hs.config.worker.worker_app is None:
+            self._push_rules_stream_id_gen = ChainedIdGenerator(
+                self._stream_id_gen, db_conn, "push_rules_stream", "stream_id"
+            )  # type: Union[ChainedIdGenerator, SlavedIdTracker]
+        else:
+            self._push_rules_stream_id_gen = SlavedIdTracker(
+                db_conn, "push_rules_stream", "stream_id"
+            )
+
+        push_rules_prefill, push_rules_id = self.db.get_cache_dict(
+            db_conn,
+            "push_rules_stream",
+            entity_column="user_id",
+            stream_column="stream_id",
+            max_value=self.get_max_push_rules_stream_id(),
+        )
+
+        self.push_rules_stream_cache = StreamChangeCache(
+            "PushRulesStreamChangeCache",
+            push_rules_id,
+            prefilled_cache=push_rules_prefill,
+        )
+
+    @abc.abstractmethod
+    def get_max_push_rules_stream_id(self):
+        """Get the position of the push rules stream.
+
+        Returns:
+            int
+        """
+        raise NotImplementedError()
+
+    @cachedInlineCallbacks(max_entries=5000)
+    def get_push_rules_for_user(self, user_id):
+        rows = yield self.db.simple_select_list(
+            table="push_rules",
+            keyvalues={"user_name": user_id},
+            retcols=(
+                "user_name",
+                "rule_id",
+                "priority_class",
+                "priority",
+                "conditions",
+                "actions",
+            ),
+            desc="get_push_rules_for_user",
+        )
+
+        rows.sort(key=lambda row: (-int(row["priority_class"]), -int(row["priority"])))
+
+        enabled_map = yield self.get_push_rules_enabled_for_user(user_id)
+
+        rules = _load_rules(rows, enabled_map)
+
+        return rules
+
+    @cachedInlineCallbacks(max_entries=5000)
+    def get_push_rules_enabled_for_user(self, user_id):
+        results = yield self.db.simple_select_list(
+            table="push_rules_enable",
+            keyvalues={"user_name": user_id},
+            retcols=("user_name", "rule_id", "enabled"),
+            desc="get_push_rules_enabled_for_user",
+        )
+        return {r["rule_id"]: False if r["enabled"] == 0 else True for r in results}
+
+    def have_push_rules_changed_for_user(self, user_id, last_id):
+        if not self.push_rules_stream_cache.has_entity_changed(user_id, last_id):
+            return defer.succeed(False)
+        else:
+
+            def have_push_rules_changed_txn(txn):
+                sql = (
+                    "SELECT COUNT(stream_id) FROM push_rules_stream"
+                    " WHERE user_id = ? AND ? < stream_id"
+                )
+                txn.execute(sql, (user_id, last_id))
+                (count,) = txn.fetchone()
+                return bool(count)
+
+            return self.db.runInteraction(
+                "have_push_rules_changed", have_push_rules_changed_txn
+            )
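
have_push_rules_changed_for_user consults push_rules_stream_cache first: when the stream-change cache can prove that the user has not changed since last_id, the database is never touched. A rough stand-in for that behaviour (assumed semantics; the real StreamChangeCache lives in synapse.util.caches.stream_change_cache):

class StreamChangeCacheSketch:
    """Track the last stream position at which each entity changed.

    has_entity_changed() may return false positives (it errs towards True
    when it cannot be sure), but never false negatives, so callers can
    safely skip the database only when it returns False.
    """

    def __init__(self, earliest_known_stream_pos):
        self._earliest = earliest_known_stream_pos
        self._last_change = {}  # entity -> stream position

    def entity_has_changed(self, entity, stream_pos):
        self._last_change[entity] = stream_pos

    def has_entity_changed(self, entity, stream_pos):
        if stream_pos < self._earliest:
            return True  # cache does not go back that far; must check the DB
        return self._last_change.get(entity, self._earliest) > stream_pos


cache = StreamChangeCacheSketch(earliest_known_stream_pos=100)
cache.entity_has_changed("@alice:hs", 120)
assert cache.has_entity_changed("@alice:hs", 110)    # changed at 120 > 110
assert not cache.has_entity_changed("@bob:hs", 110)  # no change recorded
assert cache.has_entity_changed("@bob:hs", 50)       # before the cache horizon
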
+
+    @cachedList(
+        cached_method_name="get_push_rules_for_user",
+        list_name="user_ids",
+        num_args=1,
+        inlineCallbacks=True,
+    )
+    def bulk_get_push_rules(self, user_ids):
+        if not user_ids:
+            return {}
+
+        results = {user_id: [] for user_id in user_ids}
+
+        rows = yield self.db.simple_select_many_batch(
+            table="push_rules",
+            column="user_name",
+            iterable=user_ids,
+            retcols=("*",),
+            desc="bulk_get_push_rules",
+        )
+
+        rows.sort(key=lambda row: (-int(row["priority_class"]), -int(row["priority"])))
+
+        for row in rows:
+            results.setdefault(row["user_name"], []).append(row)
+
+        enabled_map_by_user = yield self.bulk_get_push_rules_enabled(user_ids)
+
+        for user_id, rules in results.items():
+            results[user_id] = _load_rules(rules, enabled_map_by_user.get(user_id, {}))
+
+        return results
+
+    @defer.inlineCallbacks
+    def copy_push_rule_from_room_to_room(self, new_room_id, user_id, rule):
+        """Copy a single push rule from one room to another for a specific user.
+
+        Args:
+            new_room_id (str): ID of the new room.
+            user_id (str): ID of user the push rule belongs to.
+            rule (Dict): A push rule.
+        """
+        # Create new rule id
+        rule_id_scope = "/".join(rule["rule_id"].split("/")[:-1])
+        new_rule_id = rule_id_scope + "/" + new_room_id
+
+        # Change room id in each condition
+        for condition in rule.get("conditions", []):
+            if condition.get("key") == "room_id":
+                condition["pattern"] = new_room_id
+
+        # Add the rule for the new room
+        yield self.add_push_rule(
+            user_id=user_id,
+            rule_id=new_rule_id,
+            priority_class=rule["priority_class"],
+            conditions=rule["conditions"],
+            actions=rule["actions"],
+        )
+
+    @defer.inlineCallbacks
+    def copy_push_rules_from_room_to_room_for_user(
+        self, old_room_id, new_room_id, user_id
+    ):
+        """Copy all of the push rules from one room to another for a specific
+        user.
+
+        Args:
+            old_room_id (str): ID of the old room.
+            new_room_id (str): ID of the new room.
+            user_id (str): ID of user to copy push rules for.
+        """
+        # Retrieve push rules for this user
+        user_push_rules = yield self.get_push_rules_for_user(user_id)
+
+        # Get rules relating to the old room and copy them to the new room
+        for rule in user_push_rules:
+            conditions = rule.get("conditions", [])
+            if any(
+                (c.get("key") == "room_id" and c.get("pattern") == old_room_id)
+                for c in conditions
+            ):
+                yield self.copy_push_rule_from_room_to_room(new_room_id, user_id, rule)
+
+    @defer.inlineCallbacks
+    def bulk_get_push_rules_for_room(self, event, context):
+        state_group = context.state_group
+        if not state_group:
+            # If state_group is None it means it has yet to be assigned a
+            # state group, i.e. we need to make sure that calls with a state_group
+            # of None don't hit previous cached calls with a None state_group.
+            # To do this we set the state_group to a new object as object() != object()
+            state_group = object()
+
+        current_state_ids = yield context.get_current_state_ids()
+        result = yield self._bulk_get_push_rules_for_room(
+            event.room_id, state_group, current_state_ids, event=event
+        )
+        return result
+
+    @cachedInlineCallbacks(num_args=2, cache_context=True)
+    def _bulk_get_push_rules_for_room(
+        self, room_id, state_group, current_state_ids, cache_context, event=None
+    ):
+        # We don't use `state_group`; it's there so that we can cache based
+        # on it. However, it's important that it's never None, since two
+        # current_state_ids dicts with a state_group of None are likely to
+        # be different.
+        # See bulk_get_push_rules_for_room for how we work around this.
+        assert state_group is not None
+
+        # We also will want to generate notifs for other people in the room so
+        # their unread counts are correct in the event stream, but to avoid
+        # generating them for bot / AS users etc, we only do so for people who've
+        # sent a read receipt into the room.
+
+        users_in_room = yield self._get_joined_users_from_context(
+            room_id,
+            state_group,
+            current_state_ids,
+            on_invalidate=cache_context.invalidate,
+            event=event,
+        )
+
+        # We ignore app service users for now. This is so that we don't fill
+        # up the `get_if_users_have_pushers` cache with AS entries that we
+        # know don't have pushers, nor even read receipts.
+        local_users_in_room = {
+            u
+            for u in users_in_room
+            if self.hs.is_mine_id(u)
+            and not self.get_if_app_services_interested_in_user(u)
+        }
+
+        # users in the room who have pushers need to get push rules run because
+        # that's how their pushers work
+        if_users_with_pushers = yield self.get_if_users_have_pushers(
+            local_users_in_room, on_invalidate=cache_context.invalidate
+        )
+        user_ids = {
+            uid for uid, have_pusher in if_users_with_pushers.items() if have_pusher
+        }
+
+        users_with_receipts = yield self.get_users_with_read_receipts_in_room(
+            room_id, on_invalidate=cache_context.invalidate
+        )
+
+        # any users with pushers must be ours: they have pushers
+        for uid in users_with_receipts:
+            if uid in local_users_in_room:
+                user_ids.add(uid)
+
+        rules_by_user = yield self.bulk_get_push_rules(
+            user_ids, on_invalidate=cache_context.invalidate
+        )
+
+        rules_by_user = {k: v for k, v in rules_by_user.items() if v is not None}
+
+        return rules_by_user
+
+    @cachedList(
+        cached_method_name="get_push_rules_enabled_for_user",
+        list_name="user_ids",
+        num_args=1,
+        inlineCallbacks=True,
+    )
+    def bulk_get_push_rules_enabled(self, user_ids):
+        if not user_ids:
+            return {}
+
+        results = {user_id: {} for user_id in user_ids}
+
+        rows = yield self.db.simple_select_many_batch(
+            table="push_rules_enable",
+            column="user_name",
+            iterable=user_ids,
+            retcols=("user_name", "rule_id", "enabled"),
+            desc="bulk_get_push_rules_enabled",
+        )
+        for row in rows:
+            enabled = bool(row["enabled"])
+            results.setdefault(row["user_name"], {})[row["rule_id"]] = enabled
+        return results
+
+    def get_all_push_rule_updates(self, last_id, current_id, limit):
+        """Get all the push rules changes that have happend on the server"""
+        if last_id == current_id:
+            return defer.succeed([])
+
+        def get_all_push_rule_updates_txn(txn):
+            sql = (
+                "SELECT stream_id, event_stream_ordering, user_id, rule_id,"
+                " op, priority_class, priority, conditions, actions"
+                " FROM push_rules_stream"
+                " WHERE ? < stream_id AND stream_id <= ?"
+                " ORDER BY stream_id ASC LIMIT ?"
+            )
+            txn.execute(sql, (last_id, current_id, limit))
+            return txn.fetchall()
+
+        return self.db.runInteraction(
+            "get_all_push_rule_updates", get_all_push_rule_updates_txn
+        )
+
+
+class PushRuleStore(PushRulesWorkerStore):
+    @defer.inlineCallbacks
+    def add_push_rule(
+        self,
+        user_id,
+        rule_id,
+        priority_class,
+        conditions,
+        actions,
+        before=None,
+        after=None,
+    ):
+        conditions_json = json.dumps(conditions)
+        actions_json = json.dumps(actions)
+        with self._push_rules_stream_id_gen.get_next() as ids:
+            stream_id, event_stream_ordering = ids
+            if before or after:
+                yield self.db.runInteraction(
+                    "_add_push_rule_relative_txn",
+                    self._add_push_rule_relative_txn,
+                    stream_id,
+                    event_stream_ordering,
+                    user_id,
+                    rule_id,
+                    priority_class,
+                    conditions_json,
+                    actions_json,
+                    before,
+                    after,
+                )
+            else:
+                yield self.db.runInteraction(
+                    "_add_push_rule_highest_priority_txn",
+                    self._add_push_rule_highest_priority_txn,
+                    stream_id,
+                    event_stream_ordering,
+                    user_id,
+                    rule_id,
+                    priority_class,
+                    conditions_json,
+                    actions_json,
+                )
+
+    def _add_push_rule_relative_txn(
+        self,
+        txn,
+        stream_id,
+        event_stream_ordering,
+        user_id,
+        rule_id,
+        priority_class,
+        conditions_json,
+        actions_json,
+        before,
+        after,
+    ):
+        # Lock the table since otherwise we'll have annoying races between the
+        # SELECT here and the UPSERT below.
+        self.database_engine.lock_table(txn, "push_rules")
+
+        relative_to_rule = before or after
+
+        res = self.db.simple_select_one_txn(
+            txn,
+            table="push_rules",
+            keyvalues={"user_name": user_id, "rule_id": relative_to_rule},
+            retcols=["priority_class", "priority"],
+            allow_none=True,
+        )
+
+        if not res:
+            raise RuleNotFoundException(
+                "before/after rule not found: %s" % (relative_to_rule,)
+            )
+
+        base_priority_class = res["priority_class"]
+        base_rule_priority = res["priority"]
+
+        if base_priority_class != priority_class:
+            raise InconsistentRuleException(
+                "Given priority class does not match class of relative rule"
+            )
+
+        if before:
+            # Higher priority rules are executed first, so adding a rule before
+            # another rule means giving it a higher priority than that rule.
+            new_rule_priority = base_rule_priority + 1
+        else:
+            # We increment the priority of the existing rules to make space for
+            # the new rule. Therefore if we want this rule to appear after
+            # an existing rule we give it the priority of the existing rule,
+            # and then increment the priority of the existing rule.
+            new_rule_priority = base_rule_priority
+
+        sql = (
+            "UPDATE push_rules SET priority = priority + 1"
+            " WHERE user_name = ? AND priority_class = ? AND priority >= ?"
+        )
+
+        txn.execute(sql, (user_id, priority_class, new_rule_priority))
+
+        self._upsert_push_rule_txn(
+            txn,
+            stream_id,
+            event_stream_ordering,
+            user_id,
+            rule_id,
+            priority_class,
+            new_rule_priority,
+            conditions_json,
+            actions_json,
+        )
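
The before/after handling above reduces to priority arithmetic: rules fire highest priority first, so inserting before a rule takes priority base + 1, while inserting after takes the base priority, and everything at or above the new priority is shifted up by one. A toy illustration of that renumbering (not the storage code itself):

def insert_relative(priorities, base_priority, before):
    """Return (new_rule_priority, updated_priorities) for a relative insert.

    `priorities` maps rule_id -> priority within one priority class; this
    mirrors the UPDATE ... SET priority = priority + 1 WHERE priority >= ?
    step followed by the upsert of the new rule.
    """
    new_rule_priority = base_priority + 1 if before else base_priority
    updated = {
        rule_id: (p + 1 if p >= new_rule_priority else p)
        for rule_id, p in priorities.items()
    }
    return new_rule_priority, updated


rules = {"a": 2, "b": 1, "c": 0}
# Insert a new rule immediately before "b" (i.e. it should fire before "b").
new_prio, rules_after = insert_relative(rules, base_priority=1, before=True)
assert new_prio == 2
assert rules_after == {"a": 3, "b": 1, "c": 0}
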
+
+    def _add_push_rule_highest_priority_txn(
+        self,
+        txn,
+        stream_id,
+        event_stream_ordering,
+        user_id,
+        rule_id,
+        priority_class,
+        conditions_json,
+        actions_json,
+    ):
+        # Lock the table since otherwise we'll have annoying races between the
+        # SELECT here and the UPSERT below.
+        self.database_engine.lock_table(txn, "push_rules")
+
+        # find the highest priority rule in that class
+        sql = (
+            "SELECT COUNT(*), MAX(priority) FROM push_rules"
+            " WHERE user_name = ? and priority_class = ?"
+        )
+        txn.execute(sql, (user_id, priority_class))
+        res = txn.fetchall()
+        (how_many, highest_prio) = res[0]
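+        # The aggregate query always returns exactly one row; MAX(priority) is
+        # NULL (None) when the class has no rules, hence the COUNT guard below.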
+
+        new_prio = 0
+        if how_many > 0:
+            new_prio = highest_prio + 1
+
+        self._upsert_push_rule_txn(
+            txn,
+            stream_id,
+            event_stream_ordering,
+            user_id,
+            rule_id,
+            priority_class,
+            new_prio,
+            conditions_json,
+            actions_json,
+        )
+
+    def _upsert_push_rule_txn(
+        self,
+        txn,
+        stream_id,
+        event_stream_ordering,
+        user_id,
+        rule_id,
+        priority_class,
+        priority,
+        conditions_json,
+        actions_json,
+        update_stream=True,
+    ):
+        """Specialised version of simple_upsert_txn that picks a push_rule_id
+        using the _push_rule_id_gen if it needs to insert the rule. It assumes
+        that the "push_rules" table is locked"""
+
+        sql = (
+            "UPDATE push_rules"
+            " SET priority_class = ?, priority = ?, conditions = ?, actions = ?"
+            " WHERE user_name = ? AND rule_id = ?"
+        )
+
+        txn.execute(
+            sql,
+            (priority_class, priority, conditions_json, actions_json, user_id, rule_id),
+        )
+
+        if txn.rowcount == 0:
+            # We didn't update a row with the given rule_id so insert one
+            push_rule_id = self._push_rule_id_gen.get_next()
+
+            self.db.simple_insert_txn(
+                txn,
+                table="push_rules",
+                values={
+                    "id": push_rule_id,
+                    "user_name": user_id,
+                    "rule_id": rule_id,
+                    "priority_class": priority_class,
+                    "priority": priority,
+                    "conditions": conditions_json,
+                    "actions": actions_json,
+                },
+            )
+
+        if update_stream:
+            self._insert_push_rules_update_txn(
+                txn,
+                stream_id,
+                event_stream_ordering,
+                user_id,
+                rule_id,
+                op="ADD",
+                data={
+                    "priority_class": priority_class,
+                    "priority": priority,
+                    "conditions": conditions_json,
+                    "actions": actions_json,
+                },
+            )
+
+    @defer.inlineCallbacks
+    def delete_push_rule(self, user_id, rule_id):
+        """
+        Delete a push rule. The args specify the row to be deleted; they can be
+        any of the columns in the push_rules table, but the standard ones are
+        listed below.
+
+        Args:
+            user_id (str): The Matrix ID of the push rule owner
+            rule_id (str): The rule_id of the rule to be deleted
+        """
+
+        def delete_push_rule_txn(txn, stream_id, event_stream_ordering):
+            self.db.simple_delete_one_txn(
+                txn, "push_rules", {"user_name": user_id, "rule_id": rule_id}
+            )
+
+            self._insert_push_rules_update_txn(
+                txn, stream_id, event_stream_ordering, user_id, rule_id, op="DELETE"
+            )
+
+        with self._push_rules_stream_id_gen.get_next() as ids:
+            stream_id, event_stream_ordering = ids
+            yield self.db.runInteraction(
+                "delete_push_rule",
+                delete_push_rule_txn,
+                stream_id,
+                event_stream_ordering,
+            )
+
+    @defer.inlineCallbacks
+    def set_push_rule_enabled(self, user_id, rule_id, enabled):
+        with self._push_rules_stream_id_gen.get_next() as ids:
+            stream_id, event_stream_ordering = ids
+            yield self.db.runInteraction(
+                "_set_push_rule_enabled_txn",
+                self._set_push_rule_enabled_txn,
+                stream_id,
+                event_stream_ordering,
+                user_id,
+                rule_id,
+                enabled,
+            )
+
+    def _set_push_rule_enabled_txn(
+        self, txn, stream_id, event_stream_ordering, user_id, rule_id, enabled
+    ):
+        new_id = self._push_rules_enable_id_gen.get_next()
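+        # Upsert keyed on (user_name, rule_id); the freshly generated id is only
+        # used if a new row has to be inserted.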
+        self.db.simple_upsert_txn(
+            txn,
+            "push_rules_enable",
+            {"user_name": user_id, "rule_id": rule_id},
+            {"enabled": 1 if enabled else 0},
+            {"id": new_id},
+        )
+
+        self._insert_push_rules_update_txn(
+            txn,
+            stream_id,
+            event_stream_ordering,
+            user_id,
+            rule_id,
+            op="ENABLE" if enabled else "DISABLE",
+        )
+
+    @defer.inlineCallbacks
+    def set_push_rule_actions(self, user_id, rule_id, actions, is_default_rule):
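+        """Set the `actions` for the given push rule.
+
+        If `is_default_rule` is True the rule is one of the server default rules,
+        so a dummy override row (priority_class -1) carrying the new actions is
+        upserted rather than updating an existing user-defined rule.
+        """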
+        actions_json = json.dumps(actions)
+
+        def set_push_rule_actions_txn(txn, stream_id, event_stream_ordering):
+            if is_default_rule:
+                # Add a dummy rule to the push_rules table carrying the
+                # user-specified actions.
+                priority_class = -1
+                priority = 1
+                self._upsert_push_rule_txn(
+                    txn,
+                    stream_id,
+                    event_stream_ordering,
+                    user_id,
+                    rule_id,
+                    priority_class,
+                    priority,
+                    "[]",
+                    actions_json,
+                    update_stream=False,
+                )
+            else:
+                self.db.simple_update_one_txn(
+                    txn,
+                    "push_rules",
+                    {"user_name": user_id, "rule_id": rule_id},
+                    {"actions": actions_json},
+                )
+
+            self._insert_push_rules_update_txn(
+                txn,
+                stream_id,
+                event_stream_ordering,
+                user_id,
+                rule_id,
+                op="ACTIONS",
+                data={"actions": actions_json},
+            )
+
+        with self._push_rules_stream_id_gen.get_next() as ids:
+            stream_id, event_stream_ordering = ids
+            yield self.db.runInteraction(
+                "set_push_rule_actions",
+                set_push_rule_actions_txn,
+                stream_id,
+                event_stream_ordering,
+            )
+
+    def _insert_push_rules_update_txn(
+        self, txn, stream_id, event_stream_ordering, user_id, rule_id, op, data=None
+    ):
+        values = {
+            "stream_id": stream_id,
+            "event_stream_ordering": event_stream_ordering,
+            "user_id": user_id,
+            "rule_id": rule_id,
+            "op": op,
+        }
+        if data is not None:
+            values.update(data)
+
+        self.db.simple_insert_txn(txn, "push_rules_stream", values=values)
+
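+        # The call_after callbacks only run once the transaction has completed
+        # successfully, so caches are not invalidated for a rolled-back write.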
+        txn.call_after(self.get_push_rules_for_user.invalidate, (user_id,))
+        txn.call_after(self.get_push_rules_enabled_for_user.invalidate, (user_id,))
+        txn.call_after(
+            self.push_rules_stream_cache.entity_has_changed, user_id, stream_id
+        )
+
+    def get_push_rules_stream_token(self):
+        """Get the position of the push rules stream.
+        Returns a pair of the stream id for the push_rules stream and the
+        room stream ordering it corresponds to."""
+        return self._push_rules_stream_id_gen.get_current_token()
+
+    def get_max_push_rules_stream_id(self):
+        return self.get_push_rules_stream_token()[0]
diff --git a/synapse/storage/pusher.py b/synapse/storage/data_stores/main/pusher.py
index 3e0e834a62..547b9d69cb 100644
--- a/synapse/storage/pusher.py
+++ b/synapse/storage/data_stores/main/pusher.py
@@ -15,52 +15,42 @@
 # limitations under the License.
 
 import logging
-
-import six
+from typing import Iterable, Iterator
 
 from canonicaljson import encode_canonical_json, json
 
 from twisted.internet import defer
 
+from synapse.storage._base import SQLBaseStore
 from synapse.util.caches.descriptors import cachedInlineCallbacks, cachedList
 
-from ._base import SQLBaseStore
-
 logger = logging.getLogger(__name__)
 
-if six.PY2:
-    db_binary_type = six.moves.builtins.buffer
-else:
-    db_binary_type = memoryview
-
 
 class PusherWorkerStore(SQLBaseStore):
-    def _decode_pushers_rows(self, rows):
+    def _decode_pushers_rows(self, rows: Iterable[dict]) -> Iterator[dict]:
+        """JSON-decode the data in the rows returned from the `pushers` table
+
+        Drops any rows whose data cannot be decoded
+        """
         for r in rows:
             dataJson = r["data"]
-            r["data"] = None
             try:
-                if isinstance(dataJson, db_binary_type):
-                    dataJson = str(dataJson).decode("UTF8")
-
                 r["data"] = json.loads(dataJson)
             except Exception as e:
-                logger.warn(
+                logger.warning(
                     "Invalid JSON in data for pusher %d: %s, %s",
                     r["id"],
                     dataJson,
                     e.args[0],
                 )
-                pass
-
-            if isinstance(r["pushkey"], db_binary_type):
-                r["pushkey"] = str(r["pushkey"]).decode("UTF8")
+                continue
 
-        return rows
+            yield r
 
     @defer.inlineCallbacks
     def user_has_pusher(self, user_id):
-        ret = yield self._simple_select_one_onecol(
+        ret = yield self.db.simple_select_one_onecol(
             "pushers", {"user_name": user_id}, "id", allow_none=True
         )
         return ret is not None
@@ -73,7 +63,7 @@ class PusherWorkerStore(SQLBaseStore):
 
     @defer.inlineCallbacks
     def get_pushers_by(self, keyvalues):
-        ret = yield self._simple_select_list(
+        ret = yield self.db.simple_select_list(
             "pushers",
             keyvalues,
             [
@@ -101,11 +91,11 @@ class PusherWorkerStore(SQLBaseStore):
     def get_all_pushers(self):
         def get_pushers(txn):
             txn.execute("SELECT * FROM pushers")
-            rows = self.cursor_to_dict(txn)
+            rows = self.db.cursor_to_dict(txn)
 
             return self._decode_pushers_rows(rows)
 
-        rows = yield self.runInteraction("get_all_pushers", get_pushers)
+        rows = yield self.db.runInteraction("get_all_pushers", get_pushers)
         return rows
 
     def get_all_updated_pushers(self, last_id, current_id, limit):
@@ -135,7 +125,7 @@ class PusherWorkerStore(SQLBaseStore):
 
             return updated, deleted
 
-        return self.runInteraction(
+        return self.db.runInteraction(
             "get_all_updated_pushers", get_all_updated_pushers_txn
         )
 
@@ -178,7 +168,7 @@ class PusherWorkerStore(SQLBaseStore):
 
             return results
 
-        return self.runInteraction(
+        return self.db.runInteraction(
             "get_all_updated_pushers_rows", get_all_updated_pushers_rows_txn
         )
 
@@ -194,7 +184,7 @@ class PusherWorkerStore(SQLBaseStore):
         inlineCallbacks=True,
     )
     def get_if_users_have_pushers(self, user_ids):
-        rows = yield self._simple_select_many_batch(
+        rows = yield self.db.simple_select_many_batch(
             table="pushers",
             column="user_name",
             iterable=user_ids,
@@ -207,6 +197,84 @@ class PusherWorkerStore(SQLBaseStore):
 
         return result
 
+    @defer.inlineCallbacks
+    def update_pusher_last_stream_ordering(
+        self, app_id, pushkey, user_id, last_stream_ordering
+    ):
+        yield self.db.simple_update_one(
+            "pushers",
+            {"app_id": app_id, "pushkey": pushkey, "user_name": user_id},
+            {"last_stream_ordering": last_stream_ordering},
+            desc="update_pusher_last_stream_ordering",
+        )
+
+    @defer.inlineCallbacks
+    def update_pusher_last_stream_ordering_and_success(
+        self, app_id, pushkey, user_id, last_stream_ordering, last_success
+    ):
+        """Update the last stream ordering position we've processed up to for
+        the given pusher.
+
+        Args:
+            app_id (str)
+            pushkey (str)
+            user_id (str)
+            last_stream_ordering (int)
+            last_success (int)
+
+        Returns:
+            Deferred[bool]: True if the pusher still exists; False if it has been deleted.
+        """
+        updated = yield self.db.simple_update(
+            table="pushers",
+            keyvalues={"app_id": app_id, "pushkey": pushkey, "user_name": user_id},
+            updatevalues={
+                "last_stream_ordering": last_stream_ordering,
+                "last_success": last_success,
+            },
+            desc="update_pusher_last_stream_ordering_and_success",
+        )
+
+        return bool(updated)
+
+    @defer.inlineCallbacks
+    def update_pusher_failing_since(self, app_id, pushkey, user_id, failing_since):
+        yield self.db.simple_update(
+            table="pushers",
+            keyvalues={"app_id": app_id, "pushkey": pushkey, "user_name": user_id},
+            updatevalues={"failing_since": failing_since},
+            desc="update_pusher_failing_since",
+        )
+
+    @defer.inlineCallbacks
+    def get_throttle_params_by_room(self, pusher_id):
+        res = yield self.db.simple_select_list(
+            "pusher_throttle",
+            {"pusher": pusher_id},
+            ["room_id", "last_sent_ts", "throttle_ms"],
+            desc="get_throttle_params_by_room",
+        )
+
+        params_by_room = {}
+        for row in res:
+            params_by_room[row["room_id"]] = {
+                "last_sent_ts": row["last_sent_ts"],
+                "throttle_ms": row["throttle_ms"],
+            }
+
+        return params_by_room
+
+    @defer.inlineCallbacks
+    def set_throttle_params(self, pusher_id, room_id, params):
+        # no need to lock because `pusher_throttle` has a primary key on
+        # (pusher, room_id) so simple_upsert will retry
+        yield self.db.simple_upsert(
+            "pusher_throttle",
+            {"pusher": pusher_id, "room_id": room_id},
+            params,
+            desc="set_throttle_params",
+            lock=False,
+        )
+
 
 class PusherStore(PusherWorkerStore):
     def get_pushers_stream_token(self):
@@ -230,8 +298,8 @@ class PusherStore(PusherWorkerStore):
     ):
         with self._pushers_id_gen.get_next() as stream_id:
             # no need to lock because `pushers` has a unique key on
-            # (app_id, pushkey, user_name) so _simple_upsert will retry
-            yield self._simple_upsert(
+            # (app_id, pushkey, user_name) so simple_upsert will retry
+            yield self.db.simple_upsert(
                 table="pushers",
                 keyvalues={"app_id": app_id, "pushkey": pushkey, "user_name": user_id},
                 values={
@@ -241,7 +309,7 @@ class PusherStore(PusherWorkerStore):
                     "device_display_name": device_display_name,
                     "ts": pushkey_ts,
                     "lang": lang,
-                    "data": encode_canonical_json(data),
+                    "data": bytearray(encode_canonical_json(data)),
                     "last_stream_ordering": last_stream_ordering,
                     "profile_tag": profile_tag,
                     "id": stream_id,
@@ -256,7 +324,7 @@ class PusherStore(PusherWorkerStore):
 
             if user_has_pusher is not True:
                 # invalidate, since the user might not have had a pusher before
-                yield self.runInteraction(
+                yield self.db.runInteraction(
                     "add_pusher",
                     self._invalidate_cache_and_stream,
                     self.get_if_user_has_pusher,
@@ -270,7 +338,7 @@ class PusherStore(PusherWorkerStore):
                 txn, self.get_if_user_has_pusher, (user_id,)
             )
 
-            self._simple_delete_one_txn(
+            self.db.simple_delete_one_txn(
                 txn,
                 "pushers",
                 {"app_id": app_id, "pushkey": pushkey, "user_name": user_id},
@@ -279,7 +347,7 @@ class PusherStore(PusherWorkerStore):
             # it's possible for us to end up with duplicate rows for
             # (app_id, pushkey, user_id) at different stream_ids, but that
             # doesn't really matter.
-            self._simple_insert_txn(
+            self.db.simple_insert_txn(
                 txn,
                 table="deleted_pushers",
                 values={
@@ -291,82 +359,4 @@ class PusherStore(PusherWorkerStore):
             )
 
         with self._pushers_id_gen.get_next() as stream_id:
-            yield self.runInteraction("delete_pusher", delete_pusher_txn, stream_id)
-
-    @defer.inlineCallbacks
-    def update_pusher_last_stream_ordering(
-        self, app_id, pushkey, user_id, last_stream_ordering
-    ):
-        yield self._simple_update_one(
-            "pushers",
-            {"app_id": app_id, "pushkey": pushkey, "user_name": user_id},
-            {"last_stream_ordering": last_stream_ordering},
-            desc="update_pusher_last_stream_ordering",
-        )
-
-    @defer.inlineCallbacks
-    def update_pusher_last_stream_ordering_and_success(
-        self, app_id, pushkey, user_id, last_stream_ordering, last_success
-    ):
-        """Update the last stream ordering position we've processed up to for
-        the given pusher.
-
-        Args:
-            app_id (str)
-            pushkey (str)
-            last_stream_ordering (int)
-            last_success (int)
-
-        Returns:
-            Deferred[bool]: True if the pusher still exists; False if it has been deleted.
-        """
-        updated = yield self._simple_update(
-            table="pushers",
-            keyvalues={"app_id": app_id, "pushkey": pushkey, "user_name": user_id},
-            updatevalues={
-                "last_stream_ordering": last_stream_ordering,
-                "last_success": last_success,
-            },
-            desc="update_pusher_last_stream_ordering_and_success",
-        )
-
-        return bool(updated)
-
-    @defer.inlineCallbacks
-    def update_pusher_failing_since(self, app_id, pushkey, user_id, failing_since):
-        yield self._simple_update(
-            table="pushers",
-            keyvalues={"app_id": app_id, "pushkey": pushkey, "user_name": user_id},
-            updatevalues={"failing_since": failing_since},
-            desc="update_pusher_failing_since",
-        )
-
-    @defer.inlineCallbacks
-    def get_throttle_params_by_room(self, pusher_id):
-        res = yield self._simple_select_list(
-            "pusher_throttle",
-            {"pusher": pusher_id},
-            ["room_id", "last_sent_ts", "throttle_ms"],
-            desc="get_throttle_params_by_room",
-        )
-
-        params_by_room = {}
-        for row in res:
-            params_by_room[row["room_id"]] = {
-                "last_sent_ts": row["last_sent_ts"],
-                "throttle_ms": row["throttle_ms"],
-            }
-
-        return params_by_room
-
-    @defer.inlineCallbacks
-    def set_throttle_params(self, pusher_id, room_id, params):
-        # no need to lock because `pusher_throttle` has a primary key on
-        # (pusher, room_id) so _simple_upsert will retry
-        yield self._simple_upsert(
-            "pusher_throttle",
-            {"pusher": pusher_id, "room_id": room_id},
-            params,
-            desc="set_throttle_params",
-            lock=False,
-        )
+            yield self.db.runInteraction("delete_pusher", delete_pusher_txn, stream_id)
diff --git a/synapse/storage/receipts.py b/synapse/storage/data_stores/main/receipts.py
index 290ddb30e8..cebdcd409f 100644
--- a/synapse/storage/receipts.py
+++ b/synapse/storage/data_stores/main/receipts.py
@@ -21,12 +21,12 @@ from canonicaljson import json
 
 from twisted.internet import defer
 
+from synapse.storage._base import SQLBaseStore, make_in_list_sql_clause
+from synapse.storage.database import Database
+from synapse.storage.util.id_generators import StreamIdGenerator
 from synapse.util.caches.descriptors import cached, cachedInlineCallbacks, cachedList
 from synapse.util.caches.stream_change_cache import StreamChangeCache
 
-from ._base import SQLBaseStore
-from .util.id_generators import StreamIdGenerator
-
 logger = logging.getLogger(__name__)
 
 
@@ -39,8 +39,8 @@ class ReceiptsWorkerStore(SQLBaseStore):
     # the abstract methods being implemented.
     __metaclass__ = abc.ABCMeta
 
-    def __init__(self, db_conn, hs):
-        super(ReceiptsWorkerStore, self).__init__(db_conn, hs)
+    def __init__(self, database: Database, db_conn, hs):
+        super(ReceiptsWorkerStore, self).__init__(database, db_conn, hs)
 
         self._receipts_stream_cache = StreamChangeCache(
             "ReceiptsRoomChangeCache", self.get_max_receipt_stream_id()
@@ -58,11 +58,11 @@ class ReceiptsWorkerStore(SQLBaseStore):
     @cachedInlineCallbacks()
     def get_users_with_read_receipts_in_room(self, room_id):
         receipts = yield self.get_receipts_for_room(room_id, "m.read")
-        return set(r["user_id"] for r in receipts)
+        return {r["user_id"] for r in receipts}
 
     @cached(num_args=2)
     def get_receipts_for_room(self, room_id, receipt_type):
-        return self._simple_select_list(
+        return self.db.simple_select_list(
             table="receipts_linearized",
             keyvalues={"room_id": room_id, "receipt_type": receipt_type},
             retcols=("user_id", "event_id"),
@@ -71,7 +71,7 @@ class ReceiptsWorkerStore(SQLBaseStore):
 
     @cached(num_args=3)
     def get_last_receipt_event_id_for_user(self, user_id, room_id, receipt_type):
-        return self._simple_select_one_onecol(
+        return self.db.simple_select_one_onecol(
             table="receipts_linearized",
             keyvalues={
                 "room_id": room_id,
@@ -85,7 +85,7 @@ class ReceiptsWorkerStore(SQLBaseStore):
 
     @cachedInlineCallbacks(num_args=2)
     def get_receipts_for_user(self, user_id, receipt_type):
-        rows = yield self._simple_select_list(
+        rows = yield self.db.simple_select_list(
             table="receipts_linearized",
             keyvalues={"user_id": user_id, "receipt_type": receipt_type},
             retcols=("room_id", "event_id"),
@@ -109,7 +109,7 @@ class ReceiptsWorkerStore(SQLBaseStore):
             txn.execute(sql, (user_id,))
             return txn.fetchall()
 
-        rows = yield self.runInteraction("get_receipts_for_user_with_orderings", f)
+        rows = yield self.db.runInteraction("get_receipts_for_user_with_orderings", f)
         return {
             row[0]: {
                 "event_id": row[1],
@@ -188,11 +188,11 @@ class ReceiptsWorkerStore(SQLBaseStore):
 
                 txn.execute(sql, (room_id, to_key))
 
-            rows = self.cursor_to_dict(txn)
+            rows = self.db.cursor_to_dict(txn)
 
             return rows
 
-        rows = yield self.runInteraction("get_linearized_receipts_for_room", f)
+        rows = yield self.db.runInteraction("get_linearized_receipts_for_room", f)
 
         if not rows:
             return []
@@ -217,28 +217,32 @@ class ReceiptsWorkerStore(SQLBaseStore):
 
         def f(txn):
             if from_key:
-                sql = (
-                    "SELECT * FROM receipts_linearized WHERE"
-                    " room_id IN (%s) AND stream_id > ? AND stream_id <= ?"
-                ) % (",".join(["?"] * len(room_ids)))
-                args = list(room_ids)
-                args.extend([from_key, to_key])
+                sql = """
+                    SELECT * FROM receipts_linearized WHERE
+                    stream_id > ? AND stream_id <= ? AND
+                """
+                clause, args = make_in_list_sql_clause(
+                    self.database_engine, "room_id", room_ids
+                )
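+                # (make_in_list_sql_clause expands to "room_id IN (?,...)" on
+                # SQLite, or "room_id = ANY(?)" on Postgres)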
 
-                txn.execute(sql, args)
+                txn.execute(sql + clause, [from_key, to_key] + list(args))
             else:
-                sql = (
-                    "SELECT * FROM receipts_linearized WHERE"
-                    " room_id IN (%s) AND stream_id <= ?"
-                ) % (",".join(["?"] * len(room_ids)))
+                sql = """
+                    SELECT * FROM receipts_linearized WHERE
+                    stream_id <= ? AND
+                """
 
-                args = list(room_ids)
-                args.append(to_key)
+                clause, args = make_in_list_sql_clause(
+                    self.database_engine, "room_id", room_ids
+                )
 
-                txn.execute(sql, args)
+                txn.execute(sql + clause, [to_key] + list(args))
 
-            return self.cursor_to_dict(txn)
+            return self.db.cursor_to_dict(txn)
 
-        txn_results = yield self.runInteraction("_get_linearized_receipts_for_rooms", f)
+        txn_results = yield self.db.runInteraction(
+            "_get_linearized_receipts_for_rooms", f
+        )
 
         results = {}
         for row in txn_results:
@@ -279,9 +283,9 @@ class ReceiptsWorkerStore(SQLBaseStore):
                 args.append(limit)
             txn.execute(sql, args)
 
-            return (r[0:5] + (json.loads(r[5]),) for r in txn)
+            return [r[0:5] + (json.loads(r[5]),) for r in txn]
 
-        return self.runInteraction(
+        return self.db.runInteraction(
             "get_all_updated_receipts", get_all_updated_receipts_txn
         )
 
@@ -312,14 +316,14 @@ class ReceiptsWorkerStore(SQLBaseStore):
 
 
 class ReceiptsStore(ReceiptsWorkerStore):
-    def __init__(self, db_conn, hs):
+    def __init__(self, database: Database, db_conn, hs):
         # We instantiate this first as the ReceiptsWorkerStore constructor
         # needs to be able to call get_max_receipt_stream_id
         self._receipts_id_gen = StreamIdGenerator(
             db_conn, "receipts_linearized", "stream_id"
         )
 
-        super(ReceiptsStore, self).__init__(db_conn, hs)
+        super(ReceiptsStore, self).__init__(database, db_conn, hs)
 
     def get_max_receipt_stream_id(self):
         return self._receipts_id_gen.get_current_token()
@@ -334,7 +338,7 @@ class ReceiptsStore(ReceiptsWorkerStore):
             otherwise, the rx timestamp of the event that the RR corresponds to
                 (or 0 if the event is unknown)
         """
-        res = self._simple_select_one_txn(
+        res = self.db.simple_select_one_txn(
             txn,
             table="events",
             retcols=["stream_ordering", "received_ts"],
@@ -387,7 +391,7 @@ class ReceiptsStore(ReceiptsWorkerStore):
             (user_id, room_id, receipt_type),
         )
 
-        self._simple_delete_txn(
+        self.db.simple_upsert_txn(
             txn,
             table="receipts_linearized",
             keyvalues={
@@ -395,19 +399,14 @@ class ReceiptsStore(ReceiptsWorkerStore):
                 "receipt_type": receipt_type,
                 "user_id": user_id,
             },
-        )
-
-        self._simple_insert_txn(
-            txn,
-            table="receipts_linearized",
             values={
                 "stream_id": stream_id,
-                "room_id": room_id,
-                "receipt_type": receipt_type,
-                "user_id": user_id,
                 "event_id": event_id,
                 "data": json.dumps(data),
             },
+            # receipts_linearized has a unique constraint on
+            # (user_id, room_id, receipt_type), so no need to lock
+            lock=False,
         )
 
         if receipt_type == "m.read" and stream_ordering is not None:
@@ -433,26 +432,32 @@ class ReceiptsStore(ReceiptsWorkerStore):
             # we need to map points in graph form to linearized form.
             # TODO: Make this better.
             def graph_to_linear(txn):
-                query = (
-                    "SELECT event_id WHERE room_id = ? AND stream_ordering IN ("
-                    " SELECT max(stream_ordering) WHERE event_id IN (%s)"
-                    ")"
-                ) % (",".join(["?"] * len(event_ids)))
+                clause, args = make_in_list_sql_clause(
+                    self.database_engine, "event_id", event_ids
+                )
+
+                sql = """
+                    SELECT event_id FROM events WHERE room_id = ? AND stream_ordering IN (
+                        SELECT max(stream_ordering) FROM events WHERE %s
+                    )
+                """ % (
+                    clause,
+                )
 
-                txn.execute(query, [room_id] + event_ids)
+                txn.execute(sql, [room_id] + list(args))
                 rows = txn.fetchall()
                 if rows:
                     return rows[0][0]
                 else:
                     raise RuntimeError("Unrecognized event_ids: %r" % (event_ids,))
 
-            linearized_event_id = yield self.runInteraction(
+            linearized_event_id = yield self.db.runInteraction(
                 "insert_receipt_conv", graph_to_linear
             )
 
         stream_id_manager = self._receipts_id_gen.get_next()
         with stream_id_manager as stream_id:
-            event_ts = yield self.runInteraction(
+            event_ts = yield self.db.runInteraction(
                 "insert_linearized_receipt",
                 self.insert_linearized_receipt_txn,
                 room_id,
@@ -481,7 +486,7 @@ class ReceiptsStore(ReceiptsWorkerStore):
         return stream_id, max_persisted_id
 
     def insert_graph_receipt(self, room_id, receipt_type, user_id, event_ids, data):
-        return self.runInteraction(
+        return self.db.runInteraction(
             "insert_graph_receipt",
             self.insert_graph_receipt_txn,
             room_id,
@@ -507,7 +512,7 @@ class ReceiptsStore(ReceiptsWorkerStore):
             self._get_linearized_receipts_for_room.invalidate_many, (room_id,)
         )
 
-        self._simple_delete_txn(
+        self.db.simple_delete_txn(
             txn,
             table="receipts_graph",
             keyvalues={
@@ -516,7 +521,7 @@ class ReceiptsStore(ReceiptsWorkerStore):
                 "user_id": user_id,
             },
         )
-        self._simple_insert_txn(
+        self.db.simple_insert_txn(
             txn,
             table="receipts_graph",
             values={
diff --git a/synapse/storage/registration.py b/synapse/storage/data_stores/main/registration.py
index 109052fa41..9768981891 100644
--- a/synapse/storage/registration.py
+++ b/synapse/storage/data_stores/main/registration.py
@@ -17,17 +17,18 @@
 
 import logging
 import re
+from typing import Optional
 
 from six import iterkeys
-from six.moves import range
 
 from twisted.internet import defer
+from twisted.internet.defer import Deferred
 
 from synapse.api.constants import UserTypes
-from synapse.api.errors import Codes, StoreError, ThreepidValidationError
+from synapse.api.errors import Codes, StoreError, SynapseError, ThreepidValidationError
 from synapse.metrics.background_process_metrics import run_as_background_process
-from synapse.storage import background_updates
 from synapse.storage._base import SQLBaseStore
+from synapse.storage.database import Database
 from synapse.types import UserID
 from synapse.util.caches.descriptors import cached, cachedInlineCallbacks
 
@@ -37,26 +38,28 @@ logger = logging.getLogger(__name__)
 
 
 class RegistrationWorkerStore(SQLBaseStore):
-    def __init__(self, db_conn, hs):
-        super(RegistrationWorkerStore, self).__init__(db_conn, hs)
+    def __init__(self, database: Database, db_conn, hs):
+        super(RegistrationWorkerStore, self).__init__(database, db_conn, hs)
 
         self.config = hs.config
         self.clock = hs.get_clock()
 
     @cached()
     def get_user_by_id(self, user_id):
-        return self._simple_select_one(
+        return self.db.simple_select_one(
             table="users",
             keyvalues={"name": user_id},
             retcols=[
                 "name",
                 "password_hash",
                 "is_guest",
+                "admin",
                 "consent_version",
                 "consent_server_notice_sent",
                 "appservice_id",
                 "creation_ts",
                 "user_type",
+                "deactivated",
             ],
             allow_none=True,
             desc="get_user_by_id",
@@ -94,7 +97,7 @@ class RegistrationWorkerStore(SQLBaseStore):
                 including the keys `name`, `is_guest`, `device_id`, `token_id`,
                 `valid_until_ms`.
         """
-        return self.runInteraction(
+        return self.db.runInteraction(
             "get_user_by_access_token", self._query_for_auth, token
         )
 
@@ -109,7 +112,7 @@ class RegistrationWorkerStore(SQLBaseStore):
                 otherwise int representation of the timestamp (as a number of
                 milliseconds since epoch).
         """
-        res = yield self._simple_select_one_onecol(
+        res = yield self.db.simple_select_one_onecol(
             table="account_validity",
             keyvalues={"user_id": user_id},
             retcol="expiration_ts_ms",
@@ -137,7 +140,7 @@ class RegistrationWorkerStore(SQLBaseStore):
         """
 
         def set_account_validity_for_user_txn(txn):
-            self._simple_update_txn(
+            self.db.simple_update_txn(
                 txn=txn,
                 table="account_validity",
                 keyvalues={"user_id": user_id},
@@ -151,7 +154,7 @@ class RegistrationWorkerStore(SQLBaseStore):
                 txn, self.get_expiration_ts_for_user, (user_id,)
             )
 
-        yield self.runInteraction(
+        yield self.db.runInteraction(
             "set_account_validity_for_user", set_account_validity_for_user_txn
         )
 
@@ -167,7 +170,7 @@ class RegistrationWorkerStore(SQLBaseStore):
         Raises:
             StoreError: The provided token is already set for another user.
         """
-        yield self._simple_update_one(
+        yield self.db.simple_update_one(
             table="account_validity",
             keyvalues={"user_id": user_id},
             updatevalues={"renewal_token": renewal_token},
@@ -184,7 +187,7 @@ class RegistrationWorkerStore(SQLBaseStore):
         Returns:
             defer.Deferred[str]: The ID of the user to which the token belongs.
         """
-        res = yield self._simple_select_one_onecol(
+        res = yield self.db.simple_select_one_onecol(
             table="account_validity",
             keyvalues={"renewal_token": renewal_token},
             retcol="user_id",
@@ -203,7 +206,7 @@ class RegistrationWorkerStore(SQLBaseStore):
         Returns:
             defer.Deferred[str]: The renewal token associated with this user ID.
         """
-        res = yield self._simple_select_one_onecol(
+        res = yield self.db.simple_select_one_onecol(
             table="account_validity",
             keyvalues={"user_id": user_id},
             retcol="renewal_token",
@@ -229,9 +232,9 @@ class RegistrationWorkerStore(SQLBaseStore):
             )
             values = [False, now_ms, renew_at]
             txn.execute(sql, values)
-            return self.cursor_to_dict(txn)
+            return self.db.cursor_to_dict(txn)
 
-        res = yield self.runInteraction(
+        res = yield self.db.runInteraction(
             "get_users_expiring_soon",
             select_users_txn,
             self.clock.time_msec(),
@@ -250,7 +253,7 @@ class RegistrationWorkerStore(SQLBaseStore):
             email_sent (bool): Flag which indicates whether a renewal email has been sent
                 to this user.
         """
-        yield self._simple_update_one(
+        yield self.db.simple_update_one(
             table="account_validity",
             keyvalues={"user_id": user_id},
             updatevalues={"email_sent": email_sent},
@@ -265,14 +268,13 @@ class RegistrationWorkerStore(SQLBaseStore):
         Args:
             user_id (str): ID of the user to remove from the account validity table.
         """
-        yield self._simple_delete_one(
+        yield self.db.simple_delete_one(
             table="account_validity",
             keyvalues={"user_id": user_id},
             desc="delete_account_validity_for_user",
         )
 
-    @defer.inlineCallbacks
-    def is_server_admin(self, user):
+    async def is_server_admin(self, user):
         """Determines if a user is an admin of this homeserver.
 
         Args:
@@ -281,7 +283,7 @@ class RegistrationWorkerStore(SQLBaseStore):
         Returns (bool):
             true iff the user is a server admin, false otherwise.
         """
-        res = yield self._simple_select_one_onecol(
+        res = await self.db.simple_select_one_onecol(
             table="users",
             keyvalues={"name": user.to_string()},
             retcol="admin",
@@ -289,7 +291,7 @@ class RegistrationWorkerStore(SQLBaseStore):
             desc="is_server_admin",
         )
 
-        return res if res else False
+        return bool(res) if res else False
 
     def set_server_admin(self, user, admin):
         """Sets whether a user is an admin of this homeserver.
@@ -299,12 +301,16 @@ class RegistrationWorkerStore(SQLBaseStore):
             admin (bool): true iff the user is to be a server admin,
                 false otherwise.
         """
-        return self._simple_update_one(
-            table="users",
-            keyvalues={"name": user.to_string()},
-            updatevalues={"admin": 1 if admin else 0},
-            desc="set_server_admin",
-        )
+
+        def set_server_admin_txn(txn):
+            self.db.simple_update_one_txn(
+                txn, "users", {"name": user.to_string()}, {"admin": 1 if admin else 0}
+            )
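+            # get_user_by_id caches the admin flag, so that cache entry must be
+            # invalidated as well.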
+            self._invalidate_cache_and_stream(
+                txn, self.get_user_by_id, (user.to_string(),)
+            )
+
+        return self.db.runInteraction("set_server_admin", set_server_admin_txn)
 
     def _query_for_auth(self, txn, token):
         sql = (
@@ -316,7 +322,7 @@ class RegistrationWorkerStore(SQLBaseStore):
         )
 
         txn.execute(sql, (token,))
-        rows = self.cursor_to_dict(txn)
+        rows = self.db.cursor_to_dict(txn)
         if rows:
             return rows[0]
 
@@ -332,10 +338,12 @@ class RegistrationWorkerStore(SQLBaseStore):
         Returns:
             Deferred[bool]: True if user 'user_type' is null or empty string
         """
-        res = yield self.runInteraction("is_real_user", self.is_real_user_txn, user_id)
+        res = yield self.db.runInteraction(
+            "is_real_user", self.is_real_user_txn, user_id
+        )
         return res
 
-    @cachedInlineCallbacks()
+    @cached()
     def is_support_user(self, user_id):
         """Determines if the user is of type UserTypes.SUPPORT
 
@@ -345,13 +353,12 @@ class RegistrationWorkerStore(SQLBaseStore):
         Returns:
             Deferred[bool]: True if user is of type UserTypes.SUPPORT
         """
-        res = yield self.runInteraction(
+        return self.db.runInteraction(
             "is_support_user", self.is_support_user_txn, user_id
         )
-        return res
 
     def is_real_user_txn(self, txn, user_id):
-        res = self._simple_select_one_onecol_txn(
+        res = self.db.simple_select_one_onecol_txn(
             txn=txn,
             table="users",
             keyvalues={"name": user_id},
@@ -361,7 +368,7 @@ class RegistrationWorkerStore(SQLBaseStore):
         return res is None
 
     def is_support_user_txn(self, txn, user_id):
-        res = self._simple_select_one_onecol_txn(
+        res = self.db.simple_select_one_onecol_txn(
             txn=txn,
             table="users",
             keyvalues={"name": user_id},
@@ -376,13 +383,31 @@ class RegistrationWorkerStore(SQLBaseStore):
         """
 
         def f(txn):
-            sql = (
-                "SELECT name, password_hash FROM users" " WHERE lower(name) = lower(?)"
-            )
+            sql = "SELECT name, password_hash FROM users WHERE lower(name) = lower(?)"
             txn.execute(sql, (user_id,))
             return dict(txn)
 
-        return self.runInteraction("get_users_by_id_case_insensitive", f)
+        return self.db.runInteraction("get_users_by_id_case_insensitive", f)
+
+    async def get_user_by_external_id(
+        self, auth_provider: str, external_id: str
+    ) -> Optional[str]:
+        """Look up a user by their external auth id
+
+        Args:
+            auth_provider: identifier for the remote auth provider
+            external_id: id on that system
+
+        Returns:
+            str|None: the mxid of the user, or None if they are not known
+        """
+        return await self.db.simple_select_one_onecol(
+            table="user_external_ids",
+            keyvalues={"auth_provider": auth_provider, "external_id": external_id},
+            retcol="user_id",
+            allow_none=True,
+            desc="get_user_by_external_id",
+        )
 
     @defer.inlineCallbacks
     def count_all_users(self):
@@ -390,12 +415,12 @@ class RegistrationWorkerStore(SQLBaseStore):
 
         def _count_users(txn):
             txn.execute("SELECT COUNT(*) AS users FROM users")
-            rows = self.cursor_to_dict(txn)
+            rows = self.db.cursor_to_dict(txn)
             if rows:
                 return rows[0]["users"]
             return 0
 
-        ret = yield self.runInteraction("count_users", _count_users)
+        ret = yield self.db.runInteraction("count_users", _count_users)
         return ret
 
     def count_daily_user_type(self):
@@ -427,7 +452,7 @@ class RegistrationWorkerStore(SQLBaseStore):
                 results[row[0]] = row[1]
             return results
 
-        return self.runInteraction("count_daily_user_type", _count_daily_user_type)
+        return self.db.runInteraction("count_daily_user_type", _count_daily_user_type)
 
     @defer.inlineCallbacks
     def count_nonbridged_users(self):
@@ -438,10 +463,10 @@ class RegistrationWorkerStore(SQLBaseStore):
                 WHERE appservice_id IS NULL
             """
             )
-            count, = txn.fetchone()
+            (count,) = txn.fetchone()
             return count
 
-        ret = yield self.runInteraction("count_users", _count_users)
+        ret = yield self.db.runInteraction("count_users", _count_users)
         return ret
 
     @defer.inlineCallbacks
@@ -450,12 +475,12 @@ class RegistrationWorkerStore(SQLBaseStore):
 
         def _count_users(txn):
             txn.execute("SELECT COUNT(*) AS users FROM users where user_type is null")
-            rows = self.cursor_to_dict(txn)
+            rows = self.db.cursor_to_dict(txn)
             if rows:
                 return rows[0]["users"]
             return 0
 
-        ret = yield self.runInteraction("count_real_users", _count_users)
+        ret = yield self.db.runInteraction("count_real_users", _count_users)
         return ret
 
     @defer.inlineCallbacks
@@ -463,49 +488,45 @@ class RegistrationWorkerStore(SQLBaseStore):
         """
         Gets the localpart of the next generated user ID.
 
-        Generated user IDs are integers, and we aim for them to be as small as
-        we can. Unfortunately, it's possible some of them are already taken by
-        existing users, and there may be gaps in the already taken range. This
-        function returns the start of the first allocatable gap. This is to
-        avoid the case of ID 10000000 being pre-allocated, so us wasting the
-        first (and shortest) many generated user IDs.
+        Generated user IDs are integers, so we find the largest integer user ID
+        already taken and return that plus one.
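+
+        Note that gaps are never reused: if only @1 and @5 exist, the next
+        generated localpart is "6".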
         """
 
         def _find_next_generated_user_id(txn):
-            txn.execute("SELECT name FROM users")
+            # We bound between '@0' and '@a' to avoid pulling the entire table
+            # out.
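+            # ('@0' <= name < '@a' works because ASCII digits sort below letters,
+            # so every "@<digits>:..." name falls inside this range.)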
+            txn.execute("SELECT name FROM users WHERE '@0' <= name AND name < '@a'")
 
             regex = re.compile(r"^@(\d+):")
 
-            found = set()
+            max_found = 0
 
             for (user_id,) in txn:
                 match = regex.search(user_id)
                 if match:
-                    found.add(int(match.group(1)))
-            for i in range(len(found) + 1):
-                if i not in found:
-                    return i
+                    max_found = max(int(match.group(1)), max_found)
+
+            return max_found + 1
 
         return (
             (
-                yield self.runInteraction(
+                yield self.db.runInteraction(
                     "find_next_generated_user_id", _find_next_generated_user_id
                 )
             )
         )
 
-    @defer.inlineCallbacks
-    def get_user_id_by_threepid(self, medium, address, require_verified=False):
+    async def get_user_id_by_threepid(self, medium: str, address: str) -> Optional[str]:
         """Returns user id from threepid
 
         Args:
-            medium (str): threepid medium e.g. email
-            address (str): threepid address e.g. me@example.com
+            medium: threepid medium e.g. email
+            address: threepid address e.g. me@example.com
 
         Returns:
-            Deferred[str|None]: user id or None if no user id/threepid mapping exists
+            The user ID or None if no user id/threepid mapping exists
         """
-        user_id = yield self.runInteraction(
+        user_id = await self.db.runInteraction(
             "get_user_id_by_threepid", self.get_user_id_by_threepid_txn, medium, address
         )
         return user_id
@@ -521,7 +542,7 @@ class RegistrationWorkerStore(SQLBaseStore):
         Returns:
             str|None: user id or None if no user id/threepid mapping exists
         """
-        ret = self._simple_select_one_txn(
+        ret = self.db.simple_select_one_txn(
             txn,
             "user_threepids",
             {"medium": medium, "address": address},
@@ -534,7 +555,7 @@ class RegistrationWorkerStore(SQLBaseStore):
 
     @defer.inlineCallbacks
     def user_add_threepid(self, user_id, medium, address, validated_at, added_at):
-        yield self._simple_upsert(
+        yield self.db.simple_upsert(
             "user_threepids",
             {"medium": medium, "address": address},
             {"user_id": user_id, "validated_at": validated_at, "added_at": added_at},
@@ -542,7 +563,7 @@ class RegistrationWorkerStore(SQLBaseStore):
 
     @defer.inlineCallbacks
     def user_get_threepids(self, user_id):
-        ret = yield self._simple_select_list(
+        ret = yield self.db.simple_select_list(
             "user_threepids",
             {"user_id": user_id},
             ["medium", "address", "validated_at", "added_at"],
@@ -551,9 +572,22 @@ class RegistrationWorkerStore(SQLBaseStore):
         return ret
 
     def user_delete_threepid(self, user_id, medium, address):
-        return self._simple_delete(
+        return self.db.simple_delete(
             "user_threepids",
             keyvalues={"user_id": user_id, "medium": medium, "address": address},
+            desc="user_delete_threepid",
+        )
+
+    def user_delete_threepids(self, user_id: str):
+        """Delete all threepid this user has bound
+
+        Args:
+             user_id: The user id to delete all threepids of
+
+        """
+        return self.db.simple_delete(
+            "user_threepids",
+            keyvalues={"user_id": user_id},
             desc="user_delete_threepids",
         )
 
@@ -573,7 +607,7 @@ class RegistrationWorkerStore(SQLBaseStore):
         """
         # We need to use an upsert, in case they user had already bound the
         # threepid
-        return self._simple_upsert(
+        return self.db.simple_upsert(
             table="user_threepid_id_server",
             keyvalues={
                 "user_id": user_id,
@@ -586,6 +620,26 @@ class RegistrationWorkerStore(SQLBaseStore):
             desc="add_user_bound_threepid",
         )
 
+    def user_get_bound_threepids(self, user_id):
+        """Get the threepids that a user has bound to an identity server through the homeserver
+        The homeserver remembers where binds to an identity server occurred. Using this
+        method can retrieve those threepids.
+
+        Args:
+            user_id (str): The ID of the user to retrieve threepids for
+
+        Returns:
+            Deferred[list[dict]]: List of dictionaries containing the following:
+                medium (str): The medium of the threepid (e.g. "email")
+                address (str): The address of the threepid (e.g. "bob@example.com")
+        """
+        return self.db.simple_select_list(
+            table="user_threepid_id_server",
+            keyvalues={"user_id": user_id},
+            retcols=["medium", "address"],
+            desc="user_get_bound_threepids",
+        )
+
     def remove_user_bound_threepid(self, user_id, medium, address, id_server):
         """The server proxied an unbind request to the given identity server on
         behalf of the given user, so we remove the mapping of threepid to
@@ -600,7 +654,7 @@ class RegistrationWorkerStore(SQLBaseStore):
         Returns:
             Deferred
         """
-        return self._simple_delete(
+        return self.db.simple_delete(
             table="user_threepid_id_server",
             keyvalues={
                 "user_id": user_id,
@@ -623,7 +677,7 @@ class RegistrationWorkerStore(SQLBaseStore):
         Returns:
             Deferred[list[str]]: Resolves to a list of identity servers
         """
-        return self._simple_select_onecol(
+        return self.db.simple_select_onecol(
             table="user_threepid_id_server",
             keyvalues={"user_id": user_id, "medium": medium, "address": address},
             retcol="id_server",
@@ -641,7 +695,7 @@ class RegistrationWorkerStore(SQLBaseStore):
             defer.Deferred(bool): The requested value.
         """
 
-        res = yield self._simple_select_one_onecol(
+        res = yield self.db.simple_select_one_onecol(
             table="users",
             keyvalues={"name": user_id},
             retcol="deactivated",
@@ -655,24 +709,37 @@ class RegistrationWorkerStore(SQLBaseStore):
         self, medium, client_secret, address=None, sid=None, validated=True
     ):
         """Gets a session_id and last_send_attempt (if available) for a
-        client_secret/medium/(address|session_id) combo
+        combination of validation metadata
 
         Args:
             medium (str|None): The medium of the 3PID
             address (str|None): The address of the 3PID
             sid (str|None): The ID of the validation session
-            client_secret (str|None): A unique string provided by the client to
-                help identify this validation attempt
+            client_secret (str): A unique string provided by the client to help identify this
+                validation attempt
             validated (bool|None): Whether sessions should be filtered by
                 whether they have been validated already or not. None to
                 perform no filtering
 
         Returns:
-            deferred {str, int}|None: A dict containing the
-                latest session_id and send_attempt count for this 3PID.
-                Otherwise None if there hasn't been a previous attempt
+            Deferred[dict|None]: A dict containing the following:
+                * address - address of the 3pid
+                * medium - medium of the 3pid
+                * client_secret - a secret provided by the client for this validation session
+                * session_id - ID of the validation session
+                * send_attempt - a number serving to dedupe send attempts for this session
+                * validated_at - timestamp of when this session was validated if so
+
+                Otherwise None if a validation session is not found
         """
-        keyvalues = {"medium": medium, "client_secret": client_secret}
+        if not client_secret:
+            raise SynapseError(
+                400, "Missing parameter: client_secret", errcode=Codes.MISSING_PARAM
+            )
+
+        keyvalues = {"client_secret": client_secret}
+        if medium:
+            keyvalues["medium"] = medium
         if address:
             keyvalues["address"] = address
         if sid:
@@ -695,13 +762,13 @@ class RegistrationWorkerStore(SQLBaseStore):
             sql += " LIMIT 1"
 
             txn.execute(sql, list(keyvalues.values()))
-            rows = self.cursor_to_dict(txn)
+            rows = self.db.cursor_to_dict(txn)
             if not rows:
                 return None
 
             return rows[0]
 
-        return self.runInteraction(
+        return self.db.runInteraction(
             "get_threepid_validation_session", get_threepid_validation_session_txn
         )
 
@@ -715,70 +782,56 @@ class RegistrationWorkerStore(SQLBaseStore):
         """
 
         def delete_threepid_session_txn(txn):
-            self._simple_delete_txn(
+            self.db.simple_delete_txn(
                 txn,
                 table="threepid_validation_token",
                 keyvalues={"session_id": session_id},
             )
-            self._simple_delete_txn(
+            self.db.simple_delete_txn(
                 txn,
                 table="threepid_validation_session",
                 keyvalues={"session_id": session_id},
             )
 
-        return self.runInteraction(
+        return self.db.runInteraction(
             "delete_threepid_session", delete_threepid_session_txn
         )
 
 
-class RegistrationStore(
-    RegistrationWorkerStore, background_updates.BackgroundUpdateStore
-):
-    def __init__(self, db_conn, hs):
-        super(RegistrationStore, self).__init__(db_conn, hs)
+class RegistrationBackgroundUpdateStore(RegistrationWorkerStore):
+    def __init__(self, database: Database, db_conn, hs):
+        super(RegistrationBackgroundUpdateStore, self).__init__(database, db_conn, hs)
 
         self.clock = hs.get_clock()
+        self.config = hs.config
 
-        self.register_background_index_update(
+        self.db.updates.register_background_index_update(
             "access_tokens_device_index",
             index_name="access_tokens_device_id",
             table="access_tokens",
             columns=["user_id", "device_id"],
         )
 
-        self.register_background_index_update(
+        self.db.updates.register_background_index_update(
             "users_creation_ts",
             index_name="users_creation_ts",
             table="users",
             columns=["creation_ts"],
         )
 
-        self._account_validity = hs.config.account_validity
-
         # we no longer use refresh tokens, but it's possible that some people
         # might have a background update queued to build this index. Just
         # clear the background update.
-        self.register_noop_background_update("refresh_tokens_device_index")
+        self.db.updates.register_noop_background_update("refresh_tokens_device_index")
 
-        self.register_background_update_handler(
+        self.db.updates.register_background_update_handler(
             "user_threepids_grandfather", self._bg_user_threepids_grandfather
         )
 
-        self.register_background_update_handler(
+        self.db.updates.register_background_update_handler(
             "users_set_deactivated_flag", self._background_update_set_deactivated_flag
         )
 
-        # Create a background job for culling expired 3PID validity tokens
-        def start_cull():
-            # run as a background process to make sure that the database transactions
-            # have a logcontext to report to
-            return run_as_background_process(
-                "cull_expired_threepid_validation_tokens",
-                self.cull_expired_threepid_validation_tokens,
-            )
-
-        hs.get_clock().looping_call(start_cull, THIRTY_MINUTES_IN_MS)
-
     @defer.inlineCallbacks
     def _background_update_set_deactivated_flag(self, progress, batch_size):
         """Retrieves a list of all deactivated users and sets the 'deactivated' flag to 1
@@ -808,10 +861,10 @@ class RegistrationStore(
                 (last_user, batch_size),
             )
 
-            rows = self.cursor_to_dict(txn)
+            rows = self.db.cursor_to_dict(txn)
 
             if not rows:
-                return True
+                return True, 0
 
             rows_processed_nb = 0
 
@@ -822,23 +875,79 @@ class RegistrationStore(
 
             logger.info("Marked %d rows as deactivated", rows_processed_nb)
 
-            self._background_update_progress_txn(
+            self.db.updates._background_update_progress_txn(
                 txn, "users_set_deactivated_flag", {"user_id": rows[-1]["name"]}
             )
 
             if batch_size > len(rows):
-                return True
+                return True, len(rows)
             else:
-                return False
+                return False, len(rows)
 
-        end = yield self.runInteraction(
+        end, nb_processed = yield self.db.runInteraction(
             "users_set_deactivated_flag", _background_update_set_deactivated_flag_txn
         )
 
         if end:
-            yield self._end_background_update("users_set_deactivated_flag")
+            yield self.db.updates._end_background_update("users_set_deactivated_flag")
 
-        return batch_size
+        return nb_processed
+
+    @defer.inlineCallbacks
+    def _bg_user_threepids_grandfather(self, progress, batch_size):
+        """We now track which identity servers a user binds their 3PID to, so
+        we need to handle the case of existing bindings where we didn't track
+        this.
+
+        We do this by grandfathering in existing user threepids assuming that
+        they used one of the server configured trusted identity servers.
+        """
+        id_servers = set(self.config.trusted_third_party_id_servers)
+
+        def _bg_user_threepids_grandfather_txn(txn):
+            sql = """
+                INSERT INTO user_threepid_id_server
+                    (user_id, medium, address, id_server)
+                SELECT user_id, medium, address, ?
+                FROM user_threepids
+            """
+
+            txn.executemany(sql, [(id_server,) for id_server in id_servers])
+
+        if id_servers:
+            yield self.db.runInteraction(
+                "_bg_user_threepids_grandfather", _bg_user_threepids_grandfather_txn
+            )
+
+        yield self.db.updates._end_background_update("user_threepids_grandfather")
+
+        return 1
+
+
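
Editor's note: the two handlers above follow the common shape used by the relocated background-update machinery: they are registered against self.db.updates, record progress inside the transaction, and return the number of rows processed so the updater can tune future batch sizes. Below is a minimal sketch of that shape, assuming a store class with a self.db Database attribute; "example_update", the helper names and the progress key are illustrative, not part of the patch.

    # Sketch only. Assumes the module-level `from twisted.internet import defer`
    # import used above, and registration in __init__ along the lines of:
    #     self.db.updates.register_background_update_handler(
    #         "example_update", self._example_background_update
    #     )
    @defer.inlineCallbacks
    def _example_background_update(self, progress, batch_size):
        last_id = progress.get("last_id", 0)

        def _example_txn(txn):
            # process up to `batch_size` rows after `last_id` here, then record
            # where we got to so a restart resumes from that point
            nb_processed = 0
            self.db.updates._background_update_progress_txn(
                txn, "example_update", {"last_id": last_id + nb_processed}
            )
            return nb_processed

        nb_processed = yield self.db.runInteraction("example_update", _example_txn)

        if nb_processed < batch_size:
            # a short batch means we have run out of rows: finish the update
            yield self.db.updates._end_background_update("example_update")

        # the updater uses this count to adjust future batch sizes
        return nb_processed
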
+class RegistrationStore(RegistrationBackgroundUpdateStore):
+    def __init__(self, database: Database, db_conn, hs):
+        super(RegistrationStore, self).__init__(database, db_conn, hs)
+
+        self._account_validity = hs.config.account_validity
+
+        if self._account_validity.enabled:
+            self._clock.call_later(
+                0.0,
+                run_as_background_process,
+                "account_validity_set_expiration_dates",
+                self._set_expiration_date_when_missing,
+            )
+
+        # Create a background job for culling expired 3PID validity tokens
+        def start_cull():
+            # run as a background process to make sure that the database transactions
+            # have a logcontext to report to
+            return run_as_background_process(
+                "cull_expired_threepid_validation_tokens",
+                self.cull_expired_threepid_validation_tokens,
+            )
+
+        hs.get_clock().looping_call(start_cull, THIRTY_MINUTES_IN_MS)
 
     @defer.inlineCallbacks
     def add_access_token_to_user(self, user_id, token, device_id, valid_until_ms):
@@ -856,7 +965,7 @@ class RegistrationStore(
         """
         next_id = self._access_tokens_id_gen.get_next()
 
-        yield self._simple_insert(
+        yield self.db.simple_insert(
             "access_tokens",
             {
                 "id": next_id,
@@ -883,7 +992,7 @@ class RegistrationStore(
 
         Args:
             user_id (str): The desired user ID to register.
-            password_hash (str): Optional. The password hash for this user.
+            password_hash (str|None): Optional. The password hash for this user.
             was_guest (bool): Optional. Whether this is a guest account being
                 upgraded to a non-guest account.
             make_guest (boolean): True if the new user should be guest,
@@ -897,8 +1006,11 @@ class RegistrationStore(
 
         Raises:
             StoreError if the user_id could not be registered.
+
+        Returns:
+            Deferred
         """
-        return self.runInteraction(
+        return self.db.runInteraction(
             "register_user",
             self._register_user,
             user_id,
@@ -932,7 +1044,7 @@ class RegistrationStore(
                 # Ensure that the guest user actually exists
                 # ``allow_none=False`` makes this raise an exception
                 # if the row isn't in the database.
-                self._simple_select_one_txn(
+                self.db.simple_select_one_txn(
                     txn,
                     "users",
                     keyvalues={"name": user_id, "is_guest": 1},
@@ -940,7 +1052,7 @@ class RegistrationStore(
                     allow_none=False,
                 )
 
-                self._simple_update_one_txn(
+                self.db.simple_update_one_txn(
                     txn,
                     "users",
                     keyvalues={"name": user_id, "is_guest": 1},
@@ -954,7 +1066,7 @@ class RegistrationStore(
                     },
                 )
             else:
-                self._simple_insert_txn(
+                self.db.simple_insert_txn(
                     txn,
                     "users",
                     values={
@@ -999,6 +1111,26 @@ class RegistrationStore(
         self._invalidate_cache_and_stream(txn, self.get_user_by_id, (user_id,))
         txn.call_after(self.is_guest.invalidate, (user_id,))
 
+    def record_user_external_id(
+        self, auth_provider: str, external_id: str, user_id: str
+    ) -> Deferred:
+        """Record a mapping from an external user id to a mxid
+
+        Args:
+            auth_provider: identifier for the remote auth provider
+            external_id: id on that system
+            user_id: complete mxid that it is mapped to
+        """
+        return self.db.simple_insert(
+            table="user_external_ids",
+            values={
+                "auth_provider": auth_provider,
+                "external_id": external_id,
+                "user_id": user_id,
+            },
+            desc="record_user_external_id",
+        )
+
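
Editor's note: a hedged usage sketch for the new record_user_external_id helper above, as it might be called from a remote-login code path. The wrapper function, the provider name and the IDs are illustrative assumptions, not taken from the patch.

    from twisted.internet import defer

    @defer.inlineCallbacks
    def map_external_user(store):
        """Record a remote-auth-provider -> mxid mapping (illustrative values)."""
        yield store.record_user_external_id(
            auth_provider="saml",              # identifier for the remote auth provider
            external_id="u12345",              # the user's id on that system
            user_id="@alice:example.com",      # the mxid it is mapped to
        )
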
     def user_set_password_hash(self, user_id, password_hash):
         """
         NB. This does *not* evict any cache because the one use for this
@@ -1007,12 +1139,14 @@ class RegistrationStore(
         """
 
         def user_set_password_hash_txn(txn):
-            self._simple_update_one_txn(
+            self.db.simple_update_one_txn(
                 txn, "users", {"name": user_id}, {"password_hash": password_hash}
             )
             self._invalidate_cache_and_stream(txn, self.get_user_by_id, (user_id,))
 
-        return self.runInteraction("user_set_password_hash", user_set_password_hash_txn)
+        return self.db.runInteraction(
+            "user_set_password_hash", user_set_password_hash_txn
+        )
 
     def user_set_consent_version(self, user_id, consent_version):
         """Updates the user table to record privacy policy consent
@@ -1027,7 +1161,7 @@ class RegistrationStore(
         """
 
         def f(txn):
-            self._simple_update_one_txn(
+            self.db.simple_update_one_txn(
                 txn,
                 table="users",
                 keyvalues={"name": user_id},
@@ -1035,7 +1169,7 @@ class RegistrationStore(
             )
             self._invalidate_cache_and_stream(txn, self.get_user_by_id, (user_id,))
 
-        return self.runInteraction("user_set_consent_version", f)
+        return self.db.runInteraction("user_set_consent_version", f)
 
     def user_set_consent_server_notice_sent(self, user_id, consent_version):
         """Updates the user table to record that we have sent the user a server
@@ -1051,7 +1185,7 @@ class RegistrationStore(
         """
 
         def f(txn):
-            self._simple_update_one_txn(
+            self.db.simple_update_one_txn(
                 txn,
                 table="users",
                 keyvalues={"name": user_id},
@@ -1059,7 +1193,7 @@ class RegistrationStore(
             )
             self._invalidate_cache_and_stream(txn, self.get_user_by_id, (user_id,))
 
-        return self.runInteraction("user_set_consent_server_notice_sent", f)
+        return self.db.runInteraction("user_set_consent_server_notice_sent", f)
 
     def user_delete_access_tokens(self, user_id, except_token_id=None, device_id=None):
         """
@@ -1105,11 +1239,11 @@ class RegistrationStore(
 
             return tokens_and_devices
 
-        return self.runInteraction("user_delete_access_tokens", f)
+        return self.db.runInteraction("user_delete_access_tokens", f)
 
     def delete_access_token(self, access_token):
         def f(txn):
-            self._simple_delete_one_txn(
+            self.db.simple_delete_one_txn(
                 txn, table="access_tokens", keyvalues={"token": access_token}
             )
 
@@ -1117,11 +1251,11 @@ class RegistrationStore(
                 txn, self.get_user_by_access_token, (access_token,)
             )
 
-        return self.runInteraction("delete_access_token", f)
+        return self.db.runInteraction("delete_access_token", f)
 
     @cachedInlineCallbacks()
     def is_guest(self, user_id):
-        res = yield self._simple_select_one_onecol(
+        res = yield self.db.simple_select_one_onecol(
             table="users",
             keyvalues={"name": user_id},
             retcol="is_guest",
@@ -1136,7 +1270,7 @@ class RegistrationStore(
         Adds a user to the table of users who need to be parted from all the rooms they're
         in
         """
-        return self._simple_insert(
+        return self.db.simple_insert(
             "users_pending_deactivation",
             values={"user_id": user_id},
             desc="add_user_pending_deactivation",
@@ -1149,7 +1283,7 @@ class RegistrationStore(
         """
         # XXX: This should be simple_delete_one but we failed to put a unique index on
         # the table, so somehow duplicate entries have ended up in it.
-        return self._simple_delete(
+        return self.db.simple_delete(
             "users_pending_deactivation",
             keyvalues={"user_id": user_id},
             desc="del_user_pending_deactivation",
@@ -1160,7 +1294,7 @@ class RegistrationStore(
         Gets one user from the table of users waiting to be parted from all the rooms
         they're in.
         """
-        return self._simple_select_one_onecol(
+        return self.db.simple_select_one_onecol(
             "users_pending_deactivation",
             keyvalues={},
             retcol="user_id",
@@ -1168,36 +1302,6 @@ class RegistrationStore(
             desc="get_users_pending_deactivation",
         )
 
-    @defer.inlineCallbacks
-    def _bg_user_threepids_grandfather(self, progress, batch_size):
-        """We now track which identity servers a user binds their 3PID to, so
-        we need to handle the case of existing bindings where we didn't track
-        this.
-
-        We do this by grandfathering in existing user threepids assuming that
-        they used one of the server configured trusted identity servers.
-        """
-        id_servers = set(self.config.trusted_third_party_id_servers)
-
-        def _bg_user_threepids_grandfather_txn(txn):
-            sql = """
-                INSERT INTO user_threepid_id_server
-                    (user_id, medium, address, id_server)
-                SELECT user_id, medium, address, ?
-                FROM user_threepids
-            """
-
-            txn.executemany(sql, [(id_server,) for id_server in id_servers])
-
-        if id_servers:
-            yield self.runInteraction(
-                "_bg_user_threepids_grandfather", _bg_user_threepids_grandfather_txn
-            )
-
-        yield self._end_background_update("user_threepids_grandfather")
-
-        return 1
-
     def validate_threepid_session(self, session_id, client_secret, token, current_ts):
         """Attempt to validate a threepid session using a token
 
@@ -1209,6 +1313,10 @@ class RegistrationStore(
             current_ts (int): The current unix time in milliseconds. Used for
                 checking token expiry status
 
+        Raises:
+            ThreepidValidationError: if a matching validation token was not found or has
+                expired
+
         Returns:
             deferred str|None: A str representing a link to redirect the user
             to if there is one.
@@ -1216,7 +1324,7 @@ class RegistrationStore(
 
         # Insert everything into a transaction in order to run atomically
         def validate_threepid_session_txn(txn):
-            row = self._simple_select_one_txn(
+            row = self.db.simple_select_one_txn(
                 txn,
                 table="threepid_validation_session",
                 keyvalues={"session_id": session_id},
@@ -1234,7 +1342,7 @@ class RegistrationStore(
                     400, "This client_secret does not match the provided session_id"
                 )
 
-            row = self._simple_select_one_txn(
+            row = self.db.simple_select_one_txn(
                 txn,
                 table="threepid_validation_token",
                 keyvalues={"session_id": session_id, "token": token},
@@ -1259,7 +1367,7 @@ class RegistrationStore(
                 )
 
             # Looks good. Validate the session
-            self._simple_update_txn(
+            self.db.simple_update_txn(
                 txn,
                 table="threepid_validation_session",
                 keyvalues={"session_id": session_id},
@@ -1269,7 +1377,7 @@ class RegistrationStore(
             return next_link
 
         # Return next_link if it exists
-        return self.runInteraction(
+        return self.db.runInteraction(
             "validate_threepid_session_txn", validate_threepid_session_txn
         )
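
Editor's note: a caller-side sketch of the validation step above, using only the signature documented in this hunk. The wrapper function and its arguments are assumptions; ThreepidValidationError is assumed importable from synapse.api.errors, as the module itself raises it.

    from twisted.internet import defer

    from synapse.api.errors import ThreepidValidationError

    @defer.inlineCallbacks
    def check_submitted_token(store, clock, session_id, client_secret, token):
        """Validate a user-submitted 3PID token, returning the next_link or None."""
        try:
            next_link = yield store.validate_threepid_session(
                session_id, client_secret, token, clock.time_msec()
            )
        except ThreepidValidationError:
            # no matching token for this session, or the token has expired
            return None
        return next_link
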
 
@@ -1302,7 +1410,7 @@ class RegistrationStore(
         if validated_at:
             insertion_values["validated_at"] = validated_at
 
-        return self._simple_upsert(
+        return self.db.simple_upsert(
             table="threepid_validation_session",
             keyvalues={"session_id": session_id},
             values={"last_send_attempt": send_attempt},
@@ -1340,7 +1448,7 @@ class RegistrationStore(
 
         def start_or_continue_validation_session_txn(txn):
             # Create or update a validation session
-            self._simple_upsert_txn(
+            self.db.simple_upsert_txn(
                 txn,
                 table="threepid_validation_session",
                 keyvalues={"session_id": session_id},
@@ -1353,7 +1461,7 @@ class RegistrationStore(
             )
 
             # Create a new validation token with this session ID
-            self._simple_insert_txn(
+            self.db.simple_insert_txn(
                 txn,
                 table="threepid_validation_token",
                 values={
@@ -1364,7 +1472,7 @@ class RegistrationStore(
                 },
             )
 
-        return self.runInteraction(
+        return self.db.runInteraction(
             "start_or_continue_validation_session",
             start_or_continue_validation_session_txn,
         )
@@ -1379,14 +1487,30 @@ class RegistrationStore(
             """
             return txn.execute(sql, (ts,))
 
-        return self.runInteraction(
+        return self.db.runInteraction(
             "cull_expired_threepid_validation_tokens",
             cull_expired_threepid_validation_tokens_txn,
             self.clock.time_msec(),
         )
 
+    @defer.inlineCallbacks
+    def set_user_deactivated_status(self, user_id, deactivated):
+        """Set the `deactivated` property for the provided user to the provided value.
+
+        Args:
+            user_id (str): The ID of the user to set the status for.
+            deactivated (bool): The value to set for `deactivated`.
+        """
+
+        yield self.db.runInteraction(
+            "set_user_deactivated_status",
+            self.set_user_deactivated_status_txn,
+            user_id,
+            deactivated,
+        )
+
     def set_user_deactivated_status_txn(self, txn, user_id, deactivated):
-        self._simple_update_one_txn(
+        self.db.simple_update_one_txn(
             txn=txn,
             table="users",
             keyvalues={"name": user_id},
@@ -1397,17 +1521,57 @@ class RegistrationStore(
         )
 
     @defer.inlineCallbacks
-    def set_user_deactivated_status(self, user_id, deactivated):
-        """Set the `deactivated` property for the provided user to the provided value.
+    def _set_expiration_date_when_missing(self):
+        """
+        Retrieves the list of registered users that don't have an expiration date, and
+        adds an expiration date for each of them.
+        """
+
+        def select_users_with_no_expiration_date_txn(txn):
+            """Retrieves the list of registered users with no expiration date from the
+            database, filtering out deactivated users.
+            """
+            sql = (
+                "SELECT users.name FROM users"
+                " LEFT JOIN account_validity ON (users.name = account_validity.user_id)"
+                " WHERE account_validity.user_id is NULL AND users.deactivated = 0;"
+            )
+            txn.execute(sql, [])
+
+            res = self.db.cursor_to_dict(txn)
+            if res:
+                for user in res:
+                    self.set_expiration_date_for_user_txn(
+                        txn, user["name"], use_delta=True
+                    )
+
+        yield self.db.runInteraction(
+            "get_users_with_no_expiration_date",
+            select_users_with_no_expiration_date_txn,
+        )
+
+    def set_expiration_date_for_user_txn(self, txn, user_id, use_delta=False):
+        """Sets an expiration date for the account with the given user ID.
 
         Args:
-            user_id (str): The ID of the user to set the status for.
-            deactivated (bool): The value to set for `deactivated`.
+             user_id (str): User ID to set an expiration date for.
+             use_delta (bool): If set to False, the expiration date for the user will be
+                now + validity period. If set to True, this expiration date will be a
+                random value in the [now + period - d ; now + period] range, d being a
+                delta equal to 10% of the validity period.
         """
+        now_ms = self._clock.time_msec()
+        expiration_ts = now_ms + self._account_validity.period
 
-        yield self.runInteraction(
-            "set_user_deactivated_status",
-            self.set_user_deactivated_status_txn,
-            user_id,
-            deactivated,
+        if use_delta:
+            expiration_ts = self.rand.randrange(
+                expiration_ts - self._account_validity.startup_job_max_delta,
+                expiration_ts,
+            )
+
+        self.db.simple_upsert_txn(
+            txn,
+            "account_validity",
+            keyvalues={"user_id": user_id},
+            values={"expiration_ts_ms": expiration_ts, "email_sent": False},
         )
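
Editor's note: a worked sketch, with made-up numbers, of the randomised bound used by set_expiration_date_for_user_txn above when use_delta is set. The concrete period and delta are assumptions standing in for account_validity.period and account_validity.startup_job_max_delta.

    import random

    period = 30 * 24 * 3600 * 1000        # e.g. a 30-day validity period, in ms
    max_delta = period * 10 // 100        # e.g. a delta of 10% of the period
    now_ms = 1_500_000_000_000            # placeholder for self._clock.time_msec()

    upper = now_ms + period               # now + period
    lower = upper - max_delta             # now + period - d
    expiration_ts = random.randrange(lower, upper)
    assert lower <= expiration_ts < upper
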
diff --git a/synapse/storage/rejections.py b/synapse/storage/data_stores/main/rejections.py
index f4c1c2a457..27e5a2084a 100644
--- a/synapse/storage/rejections.py
+++ b/synapse/storage/data_stores/main/rejections.py
@@ -15,25 +15,14 @@
 
 import logging
 
-from ._base import SQLBaseStore
+from synapse.storage._base import SQLBaseStore
 
 logger = logging.getLogger(__name__)
 
 
 class RejectionsStore(SQLBaseStore):
-    def _store_rejections_txn(self, txn, event_id, reason):
-        self._simple_insert_txn(
-            txn,
-            table="rejections",
-            values={
-                "event_id": event_id,
-                "reason": reason,
-                "last_check": self._clock.time_msec(),
-            },
-        )
-
     def get_rejection_reason(self, event_id):
-        return self._simple_select_one_onecol(
+        return self.db.simple_select_one_onecol(
             table="rejections",
             retcol="reason",
             keyvalues={"event_id": event_id},
diff --git a/synapse/storage/data_stores/main/relations.py b/synapse/storage/data_stores/main/relations.py
new file mode 100644
index 0000000000..7d477f8d01
--- /dev/null
+++ b/synapse/storage/data_stores/main/relations.py
@@ -0,0 +1,327 @@
+# -*- coding: utf-8 -*-
+# Copyright 2019 New Vector Ltd
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import logging
+
+import attr
+
+from synapse.api.constants import RelationTypes
+from synapse.storage._base import SQLBaseStore
+from synapse.storage.data_stores.main.stream import generate_pagination_where_clause
+from synapse.storage.relations import (
+    AggregationPaginationToken,
+    PaginationChunk,
+    RelationPaginationToken,
+)
+from synapse.util.caches.descriptors import cached, cachedInlineCallbacks
+
+logger = logging.getLogger(__name__)
+
+
+class RelationsWorkerStore(SQLBaseStore):
+    @cached(tree=True)
+    def get_relations_for_event(
+        self,
+        event_id,
+        relation_type=None,
+        event_type=None,
+        aggregation_key=None,
+        limit=5,
+        direction="b",
+        from_token=None,
+        to_token=None,
+    ):
+        """Get a list of relations for an event, ordered by topological ordering.
+
+        Args:
+            event_id (str): Fetch events that relate to this event ID.
+            relation_type (str|None): Only fetch events with this relation
+                type, if given.
+            event_type (str|None): Only fetch events with this event type, if
+                given.
+            aggregation_key (str|None): Only fetch events with this aggregation
+                key, if given.
+            limit (int): Only fetch the most recent `limit` events.
+            direction (str): Whether to fetch the most recent first (`"b"`) or
+                the oldest first (`"f"`).
+            from_token (RelationPaginationToken|None): Fetch rows from the given
+                token, or from the start if None.
+            to_token (RelationPaginationToken|None): Fetch rows up to the given
+                token, or up to the end if None.
+
+        Returns:
+            Deferred[PaginationChunk]: List of event IDs that match relations
+            requested. The rows are of the form `{"event_id": "..."}`.
+        """
+
+        where_clause = ["relates_to_id = ?"]
+        where_args = [event_id]
+
+        if relation_type is not None:
+            where_clause.append("relation_type = ?")
+            where_args.append(relation_type)
+
+        if event_type is not None:
+            where_clause.append("type = ?")
+            where_args.append(event_type)
+
+        if aggregation_key:
+            where_clause.append("aggregation_key = ?")
+            where_args.append(aggregation_key)
+
+        pagination_clause = generate_pagination_where_clause(
+            direction=direction,
+            column_names=("topological_ordering", "stream_ordering"),
+            from_token=attr.astuple(from_token) if from_token else None,
+            to_token=attr.astuple(to_token) if to_token else None,
+            engine=self.database_engine,
+        )
+
+        if pagination_clause:
+            where_clause.append(pagination_clause)
+
+        if direction == "b":
+            order = "DESC"
+        else:
+            order = "ASC"
+
+        sql = """
+            SELECT event_id, topological_ordering, stream_ordering
+            FROM event_relations
+            INNER JOIN events USING (event_id)
+            WHERE %s
+            ORDER BY topological_ordering %s, stream_ordering %s
+            LIMIT ?
+        """ % (
+            " AND ".join(where_clause),
+            order,
+            order,
+        )
+
+        def _get_recent_references_for_event_txn(txn):
+            txn.execute(sql, where_args + [limit + 1])
+
+            last_topo_id = None
+            last_stream_id = None
+            events = []
+            for row in txn:
+                events.append({"event_id": row[0]})
+                last_topo_id = row[1]
+                last_stream_id = row[2]
+
+            next_batch = None
+            if len(events) > limit and last_topo_id and last_stream_id:
+                next_batch = RelationPaginationToken(last_topo_id, last_stream_id)
+
+            return PaginationChunk(
+                chunk=list(events[:limit]), next_batch=next_batch, prev_batch=from_token
+            )
+
+        return self.db.runInteraction(
+            "get_recent_references_for_event", _get_recent_references_for_event_txn
+        )
+
+    @cached(tree=True)
+    def get_aggregation_groups_for_event(
+        self,
+        event_id,
+        event_type=None,
+        limit=5,
+        direction="b",
+        from_token=None,
+        to_token=None,
+    ):
+        """Get a list of annotations on the event, grouped by event type and
+        aggregation key, sorted by count.
+
+        This is used e.g. to get which reactions have happened on an event
+        and how many of each there are.
+
+        Args:
+            event_id (str): Fetch events that relate to this event ID.
+            event_type (str|None): Only fetch events with this event type, if
+                given.
+            limit (int): Only fetch the `limit` groups.
+            direction (str): Whether to fetch the highest count first (`"b"`) or
+                the lowest count first (`"f"`).
+            from_token (AggregationPaginationToken|None): Fetch rows from the
+                given token, or from the start if None.
+            to_token (AggregationPaginationToken|None): Fetch rows up to the
+                given token, or up to the end if None.
+
+
+        Returns:
+            Deferred[PaginationChunk]: List of groups of annotations that
+            match. Each row is a dict with `type`, `key` and `count` fields.
+        """
+
+        where_clause = ["relates_to_id = ?", "relation_type = ?"]
+        where_args = [event_id, RelationTypes.ANNOTATION]
+
+        if event_type:
+            where_clause.append("type = ?")
+            where_args.append(event_type)
+
+        having_clause = generate_pagination_where_clause(
+            direction=direction,
+            column_names=("COUNT(*)", "MAX(stream_ordering)"),
+            from_token=attr.astuple(from_token) if from_token else None,
+            to_token=attr.astuple(to_token) if to_token else None,
+            engine=self.database_engine,
+        )
+
+        if direction == "b":
+            order = "DESC"
+        else:
+            order = "ASC"
+
+        if having_clause:
+            having_clause = "HAVING " + having_clause
+        else:
+            having_clause = ""
+
+        sql = """
+            SELECT type, aggregation_key, COUNT(DISTINCT sender), MAX(stream_ordering)
+            FROM event_relations
+            INNER JOIN events USING (event_id)
+            WHERE {where_clause}
+            GROUP BY relation_type, type, aggregation_key
+            {having_clause}
+            ORDER BY COUNT(*) {order}, MAX(stream_ordering) {order}
+            LIMIT ?
+        """.format(
+            where_clause=" AND ".join(where_clause),
+            order=order,
+            having_clause=having_clause,
+        )
+
+        def _get_aggregation_groups_for_event_txn(txn):
+            txn.execute(sql, where_args + [limit + 1])
+
+            next_batch = None
+            events = []
+            for row in txn:
+                events.append({"type": row[0], "key": row[1], "count": row[2]})
+                next_batch = AggregationPaginationToken(row[2], row[3])
+
+            if len(events) <= limit:
+                next_batch = None
+
+            return PaginationChunk(
+                chunk=list(events[:limit]), next_batch=next_batch, prev_batch=from_token
+            )
+
+        return self.db.runInteraction(
+            "get_aggregation_groups_for_event", _get_aggregation_groups_for_event_txn
+        )
+
+    @cachedInlineCallbacks()
+    def get_applicable_edit(self, event_id):
+        """Get the most recent edit (if any) that has happened for the given
+        event.
+
+        Correctly handles checking whether edits were allowed to happen.
+
+        Args:
+            event_id (str): The original event ID
+
+        Returns:
+            Deferred[EventBase|None]: Returns the most recent edit, if any.
+        """
+
+        # We only allow edits for `m.room.message` events that have the same sender
+        # and event type. We can't assert these things during regular event auth so
+        # we have to do the checks post hoc.
+
+        # Fetches latest edit that has the same type and sender as the
+        # original, and is an `m.room.message`.
+        sql = """
+            SELECT edit.event_id FROM events AS edit
+            INNER JOIN event_relations USING (event_id)
+            INNER JOIN events AS original ON
+                original.event_id = relates_to_id
+                AND edit.type = original.type
+                AND edit.sender = original.sender
+            WHERE
+                relates_to_id = ?
+                AND relation_type = ?
+                AND edit.type = 'm.room.message'
+            ORDER by edit.origin_server_ts DESC, edit.event_id DESC
+            LIMIT 1
+        """
+
+        def _get_applicable_edit_txn(txn):
+            txn.execute(sql, (event_id, RelationTypes.REPLACE))
+            row = txn.fetchone()
+            if row:
+                return row[0]
+
+        edit_id = yield self.db.runInteraction(
+            "get_applicable_edit", _get_applicable_edit_txn
+        )
+
+        if not edit_id:
+            return
+
+        edit_event = yield self.get_event(edit_id, allow_none=True)
+        return edit_event
+
+    def has_user_annotated_event(self, parent_id, event_type, aggregation_key, sender):
+        """Check if a user has already annotated an event with the same key
+        (e.g. already liked an event).
+
+        Args:
+            parent_id (str): The event being annotated
+            event_type (str): The event type of the annotation
+            aggregation_key (str): The aggregation key of the annotation
+            sender (str): The sender of the annotation
+
+        Returns:
+            Deferred[bool]
+        """
+
+        sql = """
+            SELECT 1 FROM event_relations
+            INNER JOIN events USING (event_id)
+            WHERE
+                relates_to_id = ?
+                AND relation_type = ?
+                AND type = ?
+                AND sender = ?
+                AND aggregation_key = ?
+            LIMIT 1;
+        """
+
+        def _get_if_user_has_annotated_event(txn):
+            txn.execute(
+                sql,
+                (
+                    parent_id,
+                    RelationTypes.ANNOTATION,
+                    event_type,
+                    sender,
+                    aggregation_key,
+                ),
+            )
+
+            return bool(txn.fetchone())
+
+        return self.db.runInteraction(
+            "get_if_user_has_annotated_event", _get_if_user_has_annotated_event
+        )
+
+
+class RelationsStore(RelationsWorkerStore):
+    pass
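
Editor's note: a hedged sketch of how a caller might page through an event's relations using the PaginationChunk/next_batch contract implemented above. The wrapper function is an assumption; the imports mirror the ones at the top of this new module.

    from twisted.internet import defer

    from synapse.api.constants import RelationTypes

    @defer.inlineCallbacks
    def collect_annotation_ids(store, event_id):
        """Collect every annotation event ID relating to `event_id`."""
        event_ids = []
        from_token = None
        while True:
            chunk = yield store.get_relations_for_event(
                event_id,
                relation_type=RelationTypes.ANNOTATION,
                limit=50,
                from_token=from_token,
            )
            event_ids.extend(row["event_id"] for row in chunk.chunk)
            if chunk.next_batch is None:
                break
            from_token = chunk.next_batch  # RelationPaginationToken for the next page
        return event_ids
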
diff --git a/synapse/storage/data_stores/main/room.py b/synapse/storage/data_stores/main/room.py
new file mode 100644
index 0000000000..46f643c6b9
--- /dev/null
+++ b/synapse/storage/data_stores/main/room.py
@@ -0,0 +1,1433 @@
+# -*- coding: utf-8 -*-
+# Copyright 2014-2016 OpenMarket Ltd
+# Copyright 2019 The Matrix.org Foundation C.I.C.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import collections
+import logging
+import re
+from abc import abstractmethod
+from enum import Enum
+from typing import Any, Dict, List, Optional, Tuple
+
+from canonicaljson import json
+
+from twisted.internet import defer
+
+from synapse.api.constants import EventTypes
+from synapse.api.errors import StoreError
+from synapse.api.room_versions import RoomVersion, RoomVersions
+from synapse.storage._base import SQLBaseStore
+from synapse.storage.data_stores.main.search import SearchStore
+from synapse.storage.database import Database, LoggingTransaction
+from synapse.types import ThirdPartyInstanceID
+from synapse.util.caches.descriptors import cached, cachedInlineCallbacks
+
+logger = logging.getLogger(__name__)
+
+
+OpsLevel = collections.namedtuple(
+    "OpsLevel", ("ban_level", "kick_level", "redact_level")
+)
+
+RatelimitOverride = collections.namedtuple(
+    "RatelimitOverride", ("messages_per_second", "burst_count")
+)
+
+
+class RoomSortOrder(Enum):
+    """
+    Enum to define the sorting method used when returning rooms with get_rooms_paginate
+
+    NAME = sort rooms alphabetically by name
+    JOINED_MEMBERS = sort rooms by membership size, highest to lowest
+    """
+
+    # ALPHABETICAL and SIZE are deprecated.
+    # ALPHABETICAL is the same as NAME.
+    ALPHABETICAL = "alphabetical"
+    # SIZE is the same as JOINED_MEMBERS.
+    SIZE = "size"
+    NAME = "name"
+    CANONICAL_ALIAS = "canonical_alias"
+    JOINED_MEMBERS = "joined_members"
+    JOINED_LOCAL_MEMBERS = "joined_local_members"
+    VERSION = "version"
+    CREATOR = "creator"
+    ENCRYPTION = "encryption"
+    FEDERATABLE = "federatable"
+    PUBLIC = "public"
+    JOIN_RULES = "join_rules"
+    GUEST_ACCESS = "guest_access"
+    HISTORY_VISIBILITY = "history_visibility"
+    STATE_EVENTS = "state_events"
+
+
+class RoomWorkerStore(SQLBaseStore):
+    def __init__(self, database: Database, db_conn, hs):
+        super(RoomWorkerStore, self).__init__(database, db_conn, hs)
+
+        self.config = hs.config
+
+    def get_room(self, room_id):
+        """Retrieve a room.
+
+        Args:
+            room_id (str): The ID of the room to retrieve.
+        Returns:
+            A dict containing the room information, or None if the room is unknown.
+        """
+        return self.db.simple_select_one(
+            table="rooms",
+            keyvalues={"room_id": room_id},
+            retcols=("room_id", "is_public", "creator"),
+            desc="get_room",
+            allow_none=True,
+        )
+
+    def get_room_with_stats(self, room_id: str):
+        """Retrieve room with statistics.
+
+        Args:
+            room_id: The ID of the room to retrieve.
+        Returns:
+            A dict containing the room information, or None if the room is unknown.
+        """
+
+        def get_room_with_stats_txn(txn, room_id):
+            sql = """
+                SELECT room_id, state.name, state.canonical_alias, curr.joined_members,
+                  curr.local_users_in_room AS joined_local_members, rooms.room_version AS version,
+                  rooms.creator, state.encryption, state.is_federatable AS federatable,
+                  rooms.is_public AS public, state.join_rules, state.guest_access,
+                  state.history_visibility, curr.current_state_events AS state_events
+                FROM rooms
+                LEFT JOIN room_stats_state state USING (room_id)
+                LEFT JOIN room_stats_current curr USING (room_id)
+                WHERE room_id = ?
+                """
+            txn.execute(sql, [room_id])
+            res = self.db.cursor_to_dict(txn)[0]
+            res["federatable"] = bool(res["federatable"])
+            res["public"] = bool(res["public"])
+            return res
+
+        return self.db.runInteraction(
+            "get_room_with_stats", get_room_with_stats_txn, room_id
+        )
+
+    def get_public_room_ids(self):
+        return self.db.simple_select_onecol(
+            table="rooms",
+            keyvalues={"is_public": True},
+            retcol="room_id",
+            desc="get_public_room_ids",
+        )
+
+    def count_public_rooms(self, network_tuple, ignore_non_federatable):
+        """Counts the number of public rooms as tracked in the room_stats_current
+        and room_stats_state tables.
+
+        Args:
+            network_tuple (ThirdPartyInstanceID|None)
+            ignore_non_federatable (bool): If true filters out non-federatable rooms
+        """
+
+        def _count_public_rooms_txn(txn):
+            query_args = []
+
+            if network_tuple:
+                if network_tuple.appservice_id:
+                    published_sql = """
+                        SELECT room_id from appservice_room_list
+                        WHERE appservice_id = ? AND network_id = ?
+                    """
+                    query_args.append(network_tuple.appservice_id)
+                    query_args.append(network_tuple.network_id)
+                else:
+                    published_sql = """
+                        SELECT room_id FROM rooms WHERE is_public
+                    """
+            else:
+                published_sql = """
+                    SELECT room_id FROM rooms WHERE is_public
+                    UNION SELECT room_id from appservice_room_list
+                """
+
+            sql = """
+                SELECT
+                    COALESCE(COUNT(*), 0)
+                FROM (
+                    %(published_sql)s
+                ) published
+                INNER JOIN room_stats_state USING (room_id)
+                INNER JOIN room_stats_current USING (room_id)
+                WHERE
+                    (
+                        join_rules = 'public' OR history_visibility = 'world_readable'
+                    )
+                    AND joined_members > 0
+            """ % {
+                "published_sql": published_sql
+            }
+
+            txn.execute(sql, query_args)
+            return txn.fetchone()[0]
+
+        return self.db.runInteraction("count_public_rooms", _count_public_rooms_txn)
+
+    @defer.inlineCallbacks
+    def get_largest_public_rooms(
+        self,
+        network_tuple: Optional[ThirdPartyInstanceID],
+        search_filter: Optional[dict],
+        limit: Optional[int],
+        bounds: Optional[Tuple[int, str]],
+        forwards: bool,
+        ignore_non_federatable: bool = False,
+    ):
+        """Gets the largest public rooms (where largest is in terms of joined
+        members, as tracked in the statistics table).
+
+        Args:
+            network_tuple
+            search_filter
+            limit: Maximum number of rows to return, unlimited otherwise.
+            bounds: An upper or lower bound to apply to the result set if given,
+                consists of a joined member count and room_id (these are
+                excluded from the result set).
+            forwards: true if paginating forwards, false if going backwards
+            ignore_non_federatable: If true filters out non-federatable rooms.
+
+        Returns:
+            Rooms in order: biggest number of joined users first.
+            We then arbitrarily use the room_id as a tie breaker.
+
+        """
+
+        where_clauses = []
+        query_args = []
+
+        if network_tuple:
+            if network_tuple.appservice_id:
+                published_sql = """
+                    SELECT room_id from appservice_room_list
+                    WHERE appservice_id = ? AND network_id = ?
+                """
+                query_args.append(network_tuple.appservice_id)
+                query_args.append(network_tuple.network_id)
+            else:
+                published_sql = """
+                    SELECT room_id FROM rooms WHERE is_public
+                """
+        else:
+            published_sql = """
+                SELECT room_id FROM rooms WHERE is_public
+                UNION SELECT room_id from appservice_room_list
+            """
+
+        # Work out the bounds if we're given them. These bounds look slightly
+        # odd, but are designed to help the query planner use indices by pulling
+        # out a common bound.
+        if bounds:
+            last_joined_members, last_room_id = bounds
+            if forwards:
+                where_clauses.append(
+                    """
+                        joined_members <= ? AND (
+                            joined_members < ? OR room_id < ?
+                        )
+                    """
+                )
+            else:
+                where_clauses.append(
+                    """
+                        joined_members >= ? AND (
+                            joined_members > ? OR room_id > ?
+                        )
+                    """
+                )
+
+            query_args += [last_joined_members, last_joined_members, last_room_id]
+
+        if ignore_non_federatable:
+            where_clauses.append("is_federatable")
+
+        if search_filter and search_filter.get("generic_search_term", None):
+            search_term = "%" + search_filter["generic_search_term"] + "%"
+
+            where_clauses.append(
+                """
+                    (
+                        LOWER(name) LIKE ?
+                        OR LOWER(topic) LIKE ?
+                        OR LOWER(canonical_alias) LIKE ?
+                    )
+                """
+            )
+            query_args += [
+                search_term.lower(),
+                search_term.lower(),
+                search_term.lower(),
+            ]
+
+        where_clause = ""
+        if where_clauses:
+            where_clause = " AND " + " AND ".join(where_clauses)
+
+        sql = """
+            SELECT
+                room_id, name, topic, canonical_alias, joined_members,
+                avatar, history_visibility, joined_members, guest_access
+            FROM (
+                %(published_sql)s
+            ) published
+            INNER JOIN room_stats_state USING (room_id)
+            INNER JOIN room_stats_current USING (room_id)
+            WHERE
+                (
+                    join_rules = 'public' OR history_visibility = 'world_readable'
+                )
+                AND joined_members > 0
+                %(where_clause)s
+            ORDER BY joined_members %(dir)s, room_id %(dir)s
+        """ % {
+            "published_sql": published_sql,
+            "where_clause": where_clause,
+            "dir": "DESC" if forwards else "ASC",
+        }
+
+        if limit is not None:
+            query_args.append(limit)
+
+            sql += """
+                LIMIT ?
+            """
+
+        def _get_largest_public_rooms_txn(txn):
+            txn.execute(sql, query_args)
+
+            results = self.db.cursor_to_dict(txn)
+
+            if not forwards:
+                results.reverse()
+
+            return results
+
+        ret_val = yield self.db.runInteraction(
+            "get_largest_public_rooms", _get_largest_public_rooms_txn
+        )
+        defer.returnValue(ret_val)
+
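
Editor's note: the bounds argument above drives keyset pagination: each page is bounded by the (joined_members, room_id) pair of the last row previously returned, rather than an OFFSET. A hedged caller-side sketch follows; the wrapper function and page size are assumptions.

    from twisted.internet import defer

    @defer.inlineCallbacks
    def first_two_pages(store):
        """Fetch the two largest pages of the public room list, 20 rooms each."""
        page_one = yield store.get_largest_public_rooms(
            network_tuple=None, search_filter=None, limit=20,
            bounds=None, forwards=True,
        )
        bounds = None
        if page_one:
            last = page_one[-1]
            # the bound excludes everything up to and including this row
            bounds = (last["joined_members"], last["room_id"])
        page_two = yield store.get_largest_public_rooms(
            network_tuple=None, search_filter=None, limit=20,
            bounds=bounds, forwards=True,
        )
        return page_one, page_two
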
+    @cached(max_entries=10000)
+    def is_room_blocked(self, room_id):
+        return self.db.simple_select_one_onecol(
+            table="blocked_rooms",
+            keyvalues={"room_id": room_id},
+            retcol="1",
+            allow_none=True,
+            desc="is_room_blocked",
+        )
+
+    async def get_rooms_paginate(
+        self,
+        start: int,
+        limit: int,
+        order_by: RoomSortOrder,
+        reverse_order: bool,
+        search_term: Optional[str],
+    ) -> Tuple[List[Dict[str, Any]], int]:
+        """Function to retrieve a paginated list of rooms as json.
+
+        Args:
+            start: offset in the list
+            limit: maximum amount of rooms to retrieve
+            order_by: the sort order of the returned list
+            reverse_order: whether to reverse the room list
+            search_term: a string to filter room names by
+        Returns:
+            A list of room dicts and an integer representing the total number of
+            rooms that exist given this query
+        """
+        # Filter room names by a string
+        where_statement = ""
+        if search_term:
+            where_statement = "WHERE state.name LIKE ?"
+
+            # Our postgres db driver converts ? -> %s in SQL strings as that's the
+            # placeholder for postgres.
+            # HOWEVER, if you put a % into your SQL then everything goes wibbly.
+            # To get around this, we're going to surround search_term with %'s
+            # before giving it to the database in python instead
+            search_term = "%" + search_term + "%"
+
+        # Set ordering
+        if RoomSortOrder(order_by) == RoomSortOrder.SIZE:
+            # Deprecated in favour of RoomSortOrder.JOINED_MEMBERS
+            order_by_column = "curr.joined_members"
+            order_by_asc = False
+        elif RoomSortOrder(order_by) == RoomSortOrder.ALPHABETICAL:
+            # Deprecated in favour of RoomSortOrder.NAME
+            order_by_column = "state.name"
+            order_by_asc = True
+        elif RoomSortOrder(order_by) == RoomSortOrder.NAME:
+            order_by_column = "state.name"
+            order_by_asc = True
+        elif RoomSortOrder(order_by) == RoomSortOrder.CANONICAL_ALIAS:
+            order_by_column = "state.canonical_alias"
+            order_by_asc = True
+        elif RoomSortOrder(order_by) == RoomSortOrder.JOINED_MEMBERS:
+            order_by_column = "curr.joined_members"
+            order_by_asc = False
+        elif RoomSortOrder(order_by) == RoomSortOrder.JOINED_LOCAL_MEMBERS:
+            order_by_column = "curr.local_users_in_room"
+            order_by_asc = False
+        elif RoomSortOrder(order_by) == RoomSortOrder.VERSION:
+            order_by_column = "rooms.room_version"
+            order_by_asc = False
+        elif RoomSortOrder(order_by) == RoomSortOrder.CREATOR:
+            order_by_column = "rooms.creator"
+            order_by_asc = True
+        elif RoomSortOrder(order_by) == RoomSortOrder.ENCRYPTION:
+            order_by_column = "state.encryption"
+            order_by_asc = True
+        elif RoomSortOrder(order_by) == RoomSortOrder.FEDERATABLE:
+            order_by_column = "state.is_federatable"
+            order_by_asc = True
+        elif RoomSortOrder(order_by) == RoomSortOrder.PUBLIC:
+            order_by_column = "rooms.is_public"
+            order_by_asc = True
+        elif RoomSortOrder(order_by) == RoomSortOrder.JOIN_RULES:
+            order_by_column = "state.join_rules"
+            order_by_asc = True
+        elif RoomSortOrder(order_by) == RoomSortOrder.GUEST_ACCESS:
+            order_by_column = "state.guest_access"
+            order_by_asc = True
+        elif RoomSortOrder(order_by) == RoomSortOrder.HISTORY_VISIBILITY:
+            order_by_column = "state.history_visibility"
+            order_by_asc = True
+        elif RoomSortOrder(order_by) == RoomSortOrder.STATE_EVENTS:
+            order_by_column = "curr.current_state_events"
+            order_by_asc = False
+        else:
+            raise StoreError(
+                500, "Incorrect value for order_by provided: %s" % order_by
+            )
+
+        # Whether to return the list in reverse order
+        if reverse_order:
+            # Flip the boolean
+            order_by_asc = not order_by_asc
+
+        # Create one query for getting the limited number of rooms that the user asked
+        # for, and another query for getting the total number of rooms that could be
+        # returned. This lets us see whether there are more rooms to paginate through.
+        info_sql = """
+            SELECT state.room_id, state.name, state.canonical_alias, curr.joined_members,
+              curr.local_users_in_room, rooms.room_version, rooms.creator,
+              state.encryption, state.is_federatable, rooms.is_public, state.join_rules,
+              state.guest_access, state.history_visibility, curr.current_state_events
+            FROM room_stats_state state
+            INNER JOIN room_stats_current curr USING (room_id)
+            INNER JOIN rooms USING (room_id)
+            %s
+            ORDER BY %s %s
+            LIMIT ?
+            OFFSET ?
+        """ % (
+            where_statement,
+            order_by_column,
+            "ASC" if order_by_asc else "DESC",
+        )
+
+        # Use a nested SELECT statement as SQL can't count(*) with an OFFSET
+        count_sql = """
+            SELECT count(*) FROM (
+              SELECT room_id FROM room_stats_state state
+              %s
+            ) AS get_room_ids
+        """ % (
+            where_statement,
+        )
+
+        def _get_rooms_paginate_txn(txn):
+            # Execute the data query
+            sql_values = (limit, start)
+            if search_term:
+                # Add the search term into the WHERE clause
+                sql_values = (search_term,) + sql_values
+            txn.execute(info_sql, sql_values)
+
+            # Refactor room query data into a structured dictionary
+            rooms = []
+            for room in txn:
+                rooms.append(
+                    {
+                        "room_id": room[0],
+                        "name": room[1],
+                        "canonical_alias": room[2],
+                        "joined_members": room[3],
+                        "joined_local_members": room[4],
+                        "version": room[5],
+                        "creator": room[6],
+                        "encryption": room[7],
+                        "federatable": room[8],
+                        "public": room[9],
+                        "join_rules": room[10],
+                        "guest_access": room[11],
+                        "history_visibility": room[12],
+                        "state_events": room[13],
+                    }
+                )
+
+            # Execute the count query
+
+            # Add the search term into the WHERE clause if present
+            sql_values = (search_term,) if search_term else ()
+            txn.execute(count_sql, sql_values)
+
+            room_count = txn.fetchone()
+            return rooms, room_count[0]
+
+        return await self.db.runInteraction(
+            "get_rooms_paginate", _get_rooms_paginate_txn,
+        )
+
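
Editor's note: a hedged usage sketch for the async get_rooms_paginate helper above, as an admin listing might call it. The wrapper function and page size are assumptions; RoomSortOrder comes from this module.

    async def list_rooms_page(store, page=1, page_size=100):
        """Fetch one page of the room list for an admin view, biggest rooms first."""
        rooms, total_rooms = await store.get_rooms_paginate(
            start=page * page_size,          # offset into the full list
            limit=page_size,                 # rooms per page
            order_by=RoomSortOrder.JOINED_MEMBERS.value,
            reverse_order=False,
            search_term=None,
        )
        return rooms, total_rooms
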
+    @cachedInlineCallbacks(max_entries=10000)
+    def get_ratelimit_for_user(self, user_id):
+        """Check if there are any overrides for ratelimiting for the given
+        user
+
+        Args:
+            user_id (str)
+
+        Returns:
+            RatelimitOverride if there is an override, else None. If the contents
+            of RatelimitOverride are None or 0 then ratelimiting has been
+            disabled for that user entirely.
+        """
+        row = yield self.db.simple_select_one(
+            table="ratelimit_override",
+            keyvalues={"user_id": user_id},
+            retcols=("messages_per_second", "burst_count"),
+            allow_none=True,
+            desc="get_ratelimit_for_user",
+        )
+
+        if row:
+            return RatelimitOverride(
+                messages_per_second=row["messages_per_second"],
+                burst_count=row["burst_count"],
+            )
+        else:
+            return None
+
+    @cachedInlineCallbacks()
+    def get_retention_policy_for_room(self, room_id):
+        """Get the retention policy for a given room.
+
+        If no retention policy has been found for this room, returns a policy defined
+        by the configured default policy (which has None as both the 'min_lifetime' and
+        the 'max_lifetime' if no default policy has been defined in the server's
+        configuration).
+
+        Args:
+            room_id (str): The ID of the room to get the retention policy of.
+
+        Returns:
+            dict[str, int|None]: The "min_lifetime" and "max_lifetime" for this room.
+        """
+
+        def get_retention_policy_for_room_txn(txn):
+            txn.execute(
+                """
+                SELECT min_lifetime, max_lifetime FROM room_retention
+                INNER JOIN current_state_events USING (event_id, room_id)
+                WHERE room_id = ?;
+                """,
+                (room_id,),
+            )
+
+            return self.db.cursor_to_dict(txn)
+
+        ret = yield self.db.runInteraction(
+            "get_retention_policy_for_room", get_retention_policy_for_room_txn,
+        )
+
+        # If we don't know this room ID, ret will be empty; in this case return
+        # the default policy.
+        if not ret:
+            defer.returnValue(
+                {
+                    "min_lifetime": self.config.retention_default_min_lifetime,
+                    "max_lifetime": self.config.retention_default_max_lifetime,
+                }
+            )
+
+        row = ret[0]
+
+        # If one of the room's policy's attributes isn't defined, use the matching
+        # attribute from the default policy.
+        # The default values will be None if no default policy has been defined, or if one
+        # of the attributes is missing from the default policy.
+        if row["min_lifetime"] is None:
+            row["min_lifetime"] = self.config.retention_default_min_lifetime
+
+        if row["max_lifetime"] is None:
+            row["max_lifetime"] = self.config.retention_default_max_lifetime
+
+        defer.returnValue(row)
+
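
Editor's note: a hedged sketch of consuming the merged retention policy above, for example from a purge job. Either lifetime may still be None if neither the room nor the server config defines it; the wrapper function is an assumption.

    from twisted.internet import defer

    @defer.inlineCallbacks
    def purge_horizon_for_room(store, clock, room_id):
        """Work out the purge cut-off for a room from its merged retention policy."""
        policy = yield store.get_retention_policy_for_room(room_id)
        max_lifetime = policy["max_lifetime"]
        if max_lifetime is None:
            return None          # no lifetime configured anywhere: nothing to purge
        return clock.time_msec() - max_lifetime
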
+    def get_media_mxcs_in_room(self, room_id):
+        """Retrieves all the local and remote media MXC URIs in a given room
+
+        Args:
+            room_id (str)
+
+        Returns:
+            The local and remote media as two lists of mxc:// URIs.
+        """
+
+        def _get_media_mxcs_in_room_txn(txn):
+            local_mxcs, remote_mxcs = self._get_media_mxcs_in_room_txn(txn, room_id)
+            local_media_mxcs = []
+            remote_media_mxcs = []
+
+            # Convert the IDs to MXC URIs
+            for media_id in local_mxcs:
+                local_media_mxcs.append("mxc://%s/%s" % (self.hs.hostname, media_id))
+            for hostname, media_id in remote_mxcs:
+                remote_media_mxcs.append("mxc://%s/%s" % (hostname, media_id))
+
+            return local_media_mxcs, remote_media_mxcs
+
+        return self.db.runInteraction(
+            "get_media_ids_in_room", _get_media_mxcs_in_room_txn
+        )
+
+    def quarantine_media_ids_in_room(self, room_id, quarantined_by):
+        """Loops through all events with media in a room and quarantines
+        the associated media.
+        """
+
+        logger.info("Quarantining media in room: %s", room_id)
+
+        def _quarantine_media_in_room_txn(txn):
+            local_mxcs, remote_mxcs = self._get_media_mxcs_in_room_txn(txn, room_id)
+            total_media_quarantined = 0
+
+            # Now update all the tables to set the quarantined_by flag
+
+            txn.executemany(
+                """
+                UPDATE local_media_repository
+                SET quarantined_by = ?
+                WHERE media_id = ?
+            """,
+                ((quarantined_by, media_id) for media_id in local_mxcs),
+            )
+
+            txn.executemany(
+                """
+                    UPDATE remote_media_cache
+                    SET quarantined_by = ?
+                    WHERE media_origin = ? AND media_id = ?
+                """,
+                (
+                    (quarantined_by, origin, media_id)
+                    for origin, media_id in remote_mxcs
+                ),
+            )
+
+            total_media_quarantined += len(local_mxcs)
+            total_media_quarantined += len(remote_mxcs)
+
+            return total_media_quarantined
+
+        return self.db.runInteraction(
+            "quarantine_media_in_room", _quarantine_media_in_room_txn
+        )
+
+    def _get_media_mxcs_in_room_txn(self, txn, room_id):
+        """Retrieves all the local and remote media MXC URIs in a given room
+
+        Args:
+            txn (cursor)
+            room_id (str)
+
+        Returns:
+            A tuple of (local media IDs, remote media), where the remote media is
+            a list of (hostname, media ID) tuples.
+        """
+        mxc_re = re.compile("^mxc://([^/]+)/([^/#?]+)")
+
+        sql = """
+            SELECT stream_ordering, json FROM events
+            JOIN event_json USING (room_id, event_id)
+            WHERE room_id = ?
+                %(where_clause)s
+                AND contains_url = ? AND outlier = ?
+            ORDER BY stream_ordering DESC
+            LIMIT ?
+        """
+        txn.execute(sql % {"where_clause": ""}, (room_id, True, False, 100))
+
+        local_media_mxcs = []
+        remote_media_mxcs = []
+
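+        # Page backwards through the room's events in batches of 100, using the
+        # lowest stream_ordering seen so far as the cursor for the next batch.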
+        while True:
+            next_token = None
+            for stream_ordering, content_json in txn:
+                next_token = stream_ordering
+                event_json = json.loads(content_json)
+                content = event_json["content"]
+                content_url = content.get("url")
+                thumbnail_url = content.get("info", {}).get("thumbnail_url")
+
+                for url in (content_url, thumbnail_url):
+                    if not url:
+                        continue
+                    matches = mxc_re.match(url)
+                    if matches:
+                        hostname = matches.group(1)
+                        media_id = matches.group(2)
+                        if hostname == self.hs.hostname:
+                            local_media_mxcs.append(media_id)
+                        else:
+                            remote_media_mxcs.append((hostname, media_id))
+
+            if next_token is None:
+                # We've gone through the whole room, so we're finished.
+                break
+
+            txn.execute(
+                sql % {"where_clause": "AND stream_ordering < ?"},
+                (room_id, next_token, True, False, 100),
+            )
+
+        return local_media_mxcs, remote_media_mxcs
+
+    def quarantine_media_by_id(
+        self, server_name: str, media_id: str, quarantined_by: str,
+    ):
+        """quarantines a single local or remote media id
+
+        Args:
+            server_name: The name of the server that holds this media
+            media_id: The ID of the media to be quarantined
+            quarantined_by: The user ID that initiated the quarantine request
+        """
+        logger.info("Quarantining media: %s/%s", server_name, media_id)
+        is_local = server_name == self.config.server_name
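+        # Local media lives in local_media_repository; anything served from another
+        # homeserver is tracked in remote_media_cache, hence the split below.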
+
+        def _quarantine_media_by_id_txn(txn):
+            local_mxcs = [media_id] if is_local else []
+            remote_mxcs = [(server_name, media_id)] if not is_local else []
+
+            return self._quarantine_media_txn(
+                txn, local_mxcs, remote_mxcs, quarantined_by
+            )
+
+        return self.db.runInteraction(
+            "quarantine_media_by_user", _quarantine_media_by_id_txn
+        )
+
+    def quarantine_media_ids_by_user(self, user_id: str, quarantined_by: str):
+        """quarantines all local media associated with a single user
+
+        Args:
+            user_id: The ID of the user to quarantine media of
+            quarantined_by: The ID of the user who made the quarantine request
+        """
+
+        def _quarantine_media_by_user_txn(txn):
+            local_media_ids = self._get_media_ids_by_user_txn(txn, user_id)
+            return self._quarantine_media_txn(txn, local_media_ids, [], quarantined_by)
+
+        return self.db.runInteraction(
+            "quarantine_media_by_user", _quarantine_media_by_user_txn
+        )
+
+    def _get_media_ids_by_user_txn(self, txn, user_id: str, filter_quarantined=True):
+        """Retrieves local media IDs by a given user
+
+        Args:
+            txn (cursor)
+            user_id: The ID of the user to retrieve media IDs of
+
+        Returns:
+            The local and remote media as a lists of tuples where the key is
+            the hostname and the value is the media ID.
+        """
+        # Local media
+        sql = """
+            SELECT media_id
+            FROM local_media_repository
+            WHERE user_id = ?
+            """
+        if filter_quarantined:
+            sql += "AND quarantined_by IS NULL"
+        txn.execute(sql, (user_id,))
+
+        local_media_ids = [row[0] for row in txn]
+
+        # TODO: Figure out all remote media a user has referenced in a message
+
+        return local_media_ids
+
+    def _quarantine_media_txn(
+        self,
+        txn,
+        local_mxcs: List[str],
+        remote_mxcs: List[Tuple[str, str]],
+        quarantined_by: str,
+    ) -> int:
+        """Quarantine local and remote media items
+
+        Args:
+            txn (cursor)
+            local_mxcs: A list of local mxc URLs
+            remote_mxcs: A list of (remote server, media id) tuples representing
+                remote mxc URLs
+            quarantined_by: The ID of the user who initiated the quarantine request
+        Returns:
+            The total number of media items quarantined
+        """
+        total_media_quarantined = 0
+
+        # Update all the tables to set the quarantined_by flag
+        txn.executemany(
+            """
+            UPDATE local_media_repository
+            SET quarantined_by = ?
+            WHERE media_id = ?
+        """,
+            ((quarantined_by, media_id) for media_id in local_mxcs),
+        )
+
+        txn.executemany(
+            """
+                UPDATE remote_media_cache
+                SET quarantined_by = ?
+                WHERE media_origin = ? AND media_id = ?
+            """,
+            ((quarantined_by, origin, media_id) for origin, media_id in remote_mxcs),
+        )
+
+        total_media_quarantined += len(local_mxcs)
+        total_media_quarantined += len(remote_mxcs)
+
+        return total_media_quarantined
+
+    def get_all_new_public_rooms(self, prev_id, current_id, limit):
+        def get_all_new_public_rooms(txn):
+            sql = """
+                SELECT stream_id, room_id, visibility, appservice_id, network_id
+                FROM public_room_list_stream
+                WHERE stream_id > ? AND stream_id <= ?
+                ORDER BY stream_id ASC
+                LIMIT ?
+            """
+
+            txn.execute(sql, (prev_id, current_id, limit))
+            return txn.fetchall()
+
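+        # If the public room stream hasn't advanced there is nothing to fetch.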
+        if prev_id == current_id:
+            return defer.succeed([])
+
+        return self.db.runInteraction(
+            "get_all_new_public_rooms", get_all_new_public_rooms
+        )
+
+
+class RoomBackgroundUpdateStore(SQLBaseStore):
+    REMOVE_TOMESTONED_ROOMS_BG_UPDATE = "remove_tombstoned_rooms_from_directory"
+    ADD_ROOMS_ROOM_VERSION_COLUMN = "add_rooms_room_version_column"
+
+    def __init__(self, database: Database, db_conn, hs):
+        super(RoomBackgroundUpdateStore, self).__init__(database, db_conn, hs)
+
+        self.config = hs.config
+
+        self.db.updates.register_background_update_handler(
+            "insert_room_retention", self._background_insert_retention,
+        )
+
+        self.db.updates.register_background_update_handler(
+            self.REMOVE_TOMESTONED_ROOMS_BG_UPDATE,
+            self._remove_tombstoned_rooms_from_directory,
+        )
+
+        self.db.updates.register_background_update_handler(
+            self.ADD_ROOMS_ROOM_VERSION_COLUMN,
+            self._background_add_rooms_room_version_column,
+        )
+
+    @defer.inlineCallbacks
+    def _background_insert_retention(self, progress, batch_size):
+        """Retrieves a list of all rooms within a range and inserts an entry for each of
+        them into the room_retention table.
+        NULLs the property's columns if missing from the retention event in the room's
+        state (or NULLs all of them if there's no retention event in the room's state),
+        so that we fall back to the server's retention policy.
+        """
+
+        last_room = progress.get("room_id", "")
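+        # Resume from the last room processed in a previous batch; an empty string
+        # sorts before any real room ID, so the first batch starts at the beginning.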
+
+        def _background_insert_retention_txn(txn):
+            txn.execute(
+                """
+                SELECT state.room_id, state.event_id, events.json
+                FROM current_state_events as state
+                LEFT JOIN event_json AS events ON (state.event_id = events.event_id)
+                WHERE state.room_id > ? AND state.type = '%s'
+                ORDER BY state.room_id ASC
+                LIMIT ?;
+                """
+                % EventTypes.Retention,
+                (last_room, batch_size),
+            )
+
+            rows = self.db.cursor_to_dict(txn)
+
+            if not rows:
+                return True
+
+            for row in rows:
+                if not row["json"]:
+                    retention_policy = {}
+                else:
+                    ev = json.loads(row["json"])
+                    retention_policy = ev["content"]
+
+                self.db.simple_insert_txn(
+                    txn=txn,
+                    table="room_retention",
+                    values={
+                        "room_id": row["room_id"],
+                        "event_id": row["event_id"],
+                        "min_lifetime": retention_policy.get("min_lifetime"),
+                        "max_lifetime": retention_policy.get("max_lifetime"),
+                    },
+                )
+
+            logger.info("Inserted %d rows into room_retention", len(rows))
+
+            self.db.updates._background_update_progress_txn(
+                txn, "insert_room_retention", {"room_id": rows[-1]["room_id"]}
+            )
+
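+            # Fewer rows than the batch size means we've reached the end.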
+            return batch_size > len(rows)
+
+        end = yield self.db.runInteraction(
+            "insert_room_retention", _background_insert_retention_txn,
+        )
+
+        if end:
+            yield self.db.updates._end_background_update("insert_room_retention")
+
+        defer.returnValue(batch_size)
+
+    async def _background_add_rooms_room_version_column(
+        self, progress: dict, batch_size: int
+    ):
+        """Background update to go and add room version inforamtion to `rooms`
+        table from `current_state_events` table.
+        """
+
+        last_room_id = progress.get("room_id", "")
+
+        def _background_add_rooms_room_version_column_txn(txn: LoggingTransaction):
+            sql = """
+                SELECT room_id, json FROM current_state_events
+                INNER JOIN event_json USING (room_id, event_id)
+                WHERE room_id > ? AND type = 'm.room.create' AND state_key = ''
+                ORDER BY room_id
+                LIMIT ?
+            """
+
+            txn.execute(sql, (last_room_id, batch_size))
+
+            updates = []
+            for room_id, event_json in txn:
+                event_dict = json.loads(event_json)
+                room_version_id = event_dict.get("content", {}).get(
+                    "room_version", RoomVersions.V1.identifier
+                )
+
+                creator = event_dict.get("content", {}).get("creator")
+
+                updates.append((room_id, creator, room_version_id))
+
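+            # An empty batch means every room has been processed, so signal that
+            # the background update is complete.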
+            if not updates:
+                return True
+
+            new_last_room_id = ""
+            for room_id, creator, room_version_id in updates:
+                # We upsert here just in case we don't already have a row,
+                # mainly for paranoia as much badness would happen if we don't
+                # insert the row and then try and get the room version for the
+                # room.
+                self.db.simple_upsert_txn(
+                    txn,
+                    table="rooms",
+                    keyvalues={"room_id": room_id},
+                    values={"room_version": room_version_id},
+                    insertion_values={"is_public": False, "creator": creator},
+                )
+                new_last_room_id = room_id
+
+            self.db.updates._background_update_progress_txn(
+                txn, self.ADD_ROOMS_ROOM_VERSION_COLUMN, {"room_id": new_last_room_id}
+            )
+
+            return False
+
+        end = await self.db.runInteraction(
+            "_background_add_rooms_room_version_column",
+            _background_add_rooms_room_version_column_txn,
+        )
+
+        if end:
+            await self.db.updates._end_background_update(
+                self.ADD_ROOMS_ROOM_VERSION_COLUMN
+            )
+
+        return batch_size
+
+    async def _remove_tombstoned_rooms_from_directory(
+        self, progress, batch_size
+    ) -> int:
+        """Removes any rooms with tombstone events from the room directory
+
+        Nowadays this is handled by the room upgrade handler, but we may have some
+        that got left behind
+        """
+
+        last_room = progress.get("room_id", "")
+
+        def _get_rooms(txn):
+            txn.execute(
+                """
+                SELECT room_id
+                FROM rooms r
+                INNER JOIN current_state_events cse USING (room_id)
+                WHERE room_id > ? AND r.is_public
+                AND cse.type = '%s' AND cse.state_key = ''
+                ORDER BY room_id ASC
+                LIMIT ?;
+                """
+                % EventTypes.Tombstone,
+                (last_room, batch_size),
+            )
+
+            return [row[0] for row in txn]
+
+        rooms = await self.db.runInteraction(
+            "get_tombstoned_directory_rooms", _get_rooms
+        )
+
+        if not rooms:
+            await self.db.updates._end_background_update(
+                self.REMOVE_TOMESTONED_ROOMS_BG_UPDATE
+            )
+            return 0
+
+        for room_id in rooms:
+            logger.info("Removing tombstoned room %s from the directory", room_id)
+            await self.set_room_is_public(room_id, False)
+
+        await self.db.updates._background_update_progress(
+            self.REMOVE_TOMESTONED_ROOMS_BG_UPDATE, {"room_id": rooms[-1]}
+        )
+
+        return len(rooms)
+
+    @abstractmethod
+    def set_room_is_public(self, room_id, is_public):
+        # this will need to be implemented if a background update is performed with
+        # existing (tombstoned, public) rooms in the database.
+        #
+        # It's overridden by RoomStore for the synapse master.
+        raise NotImplementedError()
+
+
+class RoomStore(RoomBackgroundUpdateStore, RoomWorkerStore, SearchStore):
+    def __init__(self, database: Database, db_conn, hs):
+        super(RoomStore, self).__init__(database, db_conn, hs)
+
+        self.config = hs.config
+
+    async def upsert_room_on_join(self, room_id: str, room_version: RoomVersion):
+        """Ensure that the room is stored in the table
+
+        Called when we join a room over federation, and overwrites any room version
+        currently in the table.
+        """
+        await self.db.simple_upsert(
+            desc="upsert_room_on_join",
+            table="rooms",
+            keyvalues={"room_id": room_id},
+            values={"room_version": room_version.identifier},
+            insertion_values={"is_public": False, "creator": ""},
+            # rooms has a unique constraint on room_id, so no need to lock when doing an
+            # emulated upsert.
+            lock=False,
+        )
+
+    @defer.inlineCallbacks
+    def store_room(
+        self,
+        room_id: str,
+        room_creator_user_id: str,
+        is_public: bool,
+        room_version: RoomVersion,
+    ):
+        """Stores a room.
+
+        Args:
+            room_id: The desired room ID, can be None.
+            room_creator_user_id: The user ID of the room creator.
+            is_public: True to indicate that this room should appear in
+                public room lists.
+            room_version: The version of the room
+        Raises:
+            StoreError if the room could not be stored.
+        """
+        try:
+
+            def store_room_txn(txn, next_id):
+                self.db.simple_insert_txn(
+                    txn,
+                    "rooms",
+                    {
+                        "room_id": room_id,
+                        "creator": room_creator_user_id,
+                        "is_public": is_public,
+                        "room_version": room_version.identifier,
+                    },
+                )
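+                # A published room also gets an entry in the public room list
+                # stream so that the change propagates to anything following it.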
+                if is_public:
+                    self.db.simple_insert_txn(
+                        txn,
+                        table="public_room_list_stream",
+                        values={
+                            "stream_id": next_id,
+                            "room_id": room_id,
+                            "visibility": is_public,
+                        },
+                    )
+
+            with self._public_room_id_gen.get_next() as next_id:
+                yield self.db.runInteraction("store_room_txn", store_room_txn, next_id)
+        except Exception as e:
+            logger.error("store_room with room_id=%s failed: %s", room_id, e)
+            raise StoreError(500, "Problem creating room.")
+
+    async def maybe_store_room_on_invite(self, room_id: str, room_version: RoomVersion):
+        """
+        When we receive an invite over federation, store the version of the room if we
+        don't already know the room version.
+        """
+        await self.db.simple_upsert(
+            desc="maybe_store_room_on_invite",
+            table="rooms",
+            keyvalues={"room_id": room_id},
+            values={},
+            insertion_values={
+                "room_version": room_version.identifier,
+                "is_public": False,
+                "creator": "",
+            },
+            # rooms has a unique constraint on room_id, so no need to lock when doing an
+            # emulated upsert.
+            lock=False,
+        )
+
+    @defer.inlineCallbacks
+    def set_room_is_public(self, room_id, is_public):
+        def set_room_is_public_txn(txn, next_id):
+            self.db.simple_update_one_txn(
+                txn,
+                table="rooms",
+                keyvalues={"room_id": room_id},
+                updatevalues={"is_public": is_public},
+            )
+
+            entries = self.db.simple_select_list_txn(
+                txn,
+                table="public_room_list_stream",
+                keyvalues={
+                    "room_id": room_id,
+                    "appservice_id": None,
+                    "network_id": None,
+                },
+                retcols=("stream_id", "visibility"),
+            )
+
+            entries.sort(key=lambda r: r["stream_id"])
+
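+            # Only append a new stream entry if the visibility actually changed
+            # compared with the most recent entry for this room.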
+            add_to_stream = True
+            if entries:
+                add_to_stream = bool(entries[-1]["visibility"]) != is_public
+
+            if add_to_stream:
+                self.db.simple_insert_txn(
+                    txn,
+                    table="public_room_list_stream",
+                    values={
+                        "stream_id": next_id,
+                        "room_id": room_id,
+                        "visibility": is_public,
+                        "appservice_id": None,
+                        "network_id": None,
+                    },
+                )
+
+        with self._public_room_id_gen.get_next() as next_id:
+            yield self.db.runInteraction(
+                "set_room_is_public", set_room_is_public_txn, next_id
+            )
+        self.hs.get_notifier().on_new_replication_data()
+
+    @defer.inlineCallbacks
+    def set_room_is_public_appservice(
+        self, room_id, appservice_id, network_id, is_public
+    ):
+        """Edit the appservice/network specific public room list.
+
+        Each appservice can have a number of published room lists associated
+        with them, keyed off of an appservice defined `network_id`, which
+        basically represents a single instance of a bridge to a third party
+        network.
+
+        Args:
+            room_id (str)
+            appservice_id (str)
+            network_id (str)
+            is_public (bool): Whether to publish or unpublish the room from the
+                list.
+        """
+
+        def set_room_is_public_appservice_txn(txn, next_id):
+            if is_public:
+                try:
+                    self.db.simple_insert_txn(
+                        txn,
+                        table="appservice_room_list",
+                        values={
+                            "appservice_id": appservice_id,
+                            "network_id": network_id,
+                            "room_id": room_id,
+                        },
+                    )
+                except self.database_engine.module.IntegrityError:
+                    # We've already inserted, nothing to do.
+                    return
+            else:
+                self.db.simple_delete_txn(
+                    txn,
+                    table="appservice_room_list",
+                    keyvalues={
+                        "appservice_id": appservice_id,
+                        "network_id": network_id,
+                        "room_id": room_id,
+                    },
+                )
+
+            entries = self.db.simple_select_list_txn(
+                txn,
+                table="public_room_list_stream",
+                keyvalues={
+                    "room_id": room_id,
+                    "appservice_id": appservice_id,
+                    "network_id": network_id,
+                },
+                retcols=("stream_id", "visibility"),
+            )
+
+            entries.sort(key=lambda r: r["stream_id"])
+
+            add_to_stream = True
+            if entries:
+                add_to_stream = bool(entries[-1]["visibility"]) != is_public
+
+            if add_to_stream:
+                self.db.simple_insert_txn(
+                    txn,
+                    table="public_room_list_stream",
+                    values={
+                        "stream_id": next_id,
+                        "room_id": room_id,
+                        "visibility": is_public,
+                        "appservice_id": appservice_id,
+                        "network_id": network_id,
+                    },
+                )
+
+        with self._public_room_id_gen.get_next() as next_id:
+            yield self.db.runInteraction(
+                "set_room_is_public_appservice",
+                set_room_is_public_appservice_txn,
+                next_id,
+            )
+        self.hs.get_notifier().on_new_replication_data()
+
+    def get_room_count(self):
+        """Retrieve a list of all rooms
+        """
+
+        def f(txn):
+            sql = "SELECT count(*)  FROM rooms"
+            txn.execute(sql)
+            row = txn.fetchone()
+            return row[0] or 0
+
+        return self.db.runInteraction("get_rooms", f)
+
+    def add_event_report(
+        self, room_id, event_id, user_id, reason, content, received_ts
+    ):
+        next_id = self._event_reports_id_gen.get_next()
+        return self.db.simple_insert(
+            table="event_reports",
+            values={
+                "id": next_id,
+                "received_ts": received_ts,
+                "room_id": room_id,
+                "event_id": event_id,
+                "user_id": user_id,
+                "reason": reason,
+                "content": json.dumps(content),
+            },
+            desc="add_event_report",
+        )
+
+    def get_current_public_room_stream_id(self):
+        return self._public_room_id_gen.get_current_token()
+
+    @defer.inlineCallbacks
+    def block_room(self, room_id, user_id):
+        """Marks the room as blocked. Can be called multiple times.
+
+        Args:
+            room_id (str): Room to block
+            user_id (str): Who blocked it
+
+        Returns:
+            Deferred
+        """
+        yield self.db.simple_upsert(
+            table="blocked_rooms",
+            keyvalues={"room_id": room_id},
+            values={},
+            insertion_values={"user_id": user_id},
+            desc="block_room",
+        )
+        yield self.db.runInteraction(
+            "block_room_invalidation",
+            self._invalidate_cache_and_stream,
+            self.is_room_blocked,
+            (room_id,),
+        )
+
+    @defer.inlineCallbacks
+    def get_rooms_for_retention_period_in_range(
+        self, min_ms, max_ms, include_null=False
+    ):
+        """Retrieves all of the rooms within the given retention range.
+
+        Optionally includes the rooms which don't have a retention policy.
+
+        Args:
+            min_ms (int|None): Duration in milliseconds that defines the lower limit of
+                the range to handle (exclusive). If None, doesn't set a lower limit.
+            max_ms (int|None): Duration in milliseconds that defines the upper limit of
+                the range to handle (inclusive). If None, doesn't set an upper limit.
+            include_null (bool): Whether to include rooms whose retention policy is NULL
+                in the returned set.
+
+        Returns:
+            dict[str, dict]: The rooms within this range, along with their retention
+                policy. The key is "room_id", and maps to a dict describing the retention
+                policy associated with this room ID. The keys for this nested dict are
+                "min_lifetime" (int|None), and "max_lifetime" (int|None).
+        """
+
+        def get_rooms_for_retention_period_in_range_txn(txn):
+            range_conditions = []
+            args = []
+
+            if min_ms is not None:
+                range_conditions.append("max_lifetime > ?")
+                args.append(min_ms)
+
+            if max_ms is not None:
+                range_conditions.append("max_lifetime <= ?")
+                args.append(max_ms)
+
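+            # With both bounds set, the WHERE clause below ends up as
+            # "(max_lifetime > ? AND max_lifetime <= ?)", optionally followed by
+            # "OR max_lifetime IS NULL" when include_null is set.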
+            # Do a first query which will retrieve the rooms that have a retention policy
+            # in their current state.
+            sql = """
+                SELECT room_id, min_lifetime, max_lifetime FROM room_retention
+                INNER JOIN current_state_events USING (event_id, room_id)
+                """
+
+            if range_conditions:
+                sql += " WHERE (" + " AND ".join(range_conditions) + ")"
+
+                if include_null:
+                    sql += " OR max_lifetime IS NULL"
+
+            txn.execute(sql, args)
+
+            rows = self.db.cursor_to_dict(txn)
+            rooms_dict = {}
+
+            for row in rows:
+                rooms_dict[row["room_id"]] = {
+                    "min_lifetime": row["min_lifetime"],
+                    "max_lifetime": row["max_lifetime"],
+                }
+
+            if include_null:
+                # If required, do a second query that retrieves all of the rooms we know
+                # of so we can handle rooms with no retention policy.
+                sql = "SELECT DISTINCT room_id FROM current_state_events"
+
+                txn.execute(sql)
+
+                rows = self.db.cursor_to_dict(txn)
+
+                # If a room isn't already in the dict (i.e. it doesn't have a retention
+                # policy in its state), add it with a null policy.
+                for row in rows:
+                    if row["room_id"] not in rooms_dict:
+                        rooms_dict[row["room_id"]] = {
+                            "min_lifetime": None,
+                            "max_lifetime": None,
+                        }
+
+            return rooms_dict
+
+        rooms = yield self.db.runInteraction(
+            "get_rooms_for_retention_period_in_range",
+            get_rooms_for_retention_period_in_range_txn,
+        )
+
+        defer.returnValue(rooms)
diff --git a/synapse/storage/data_stores/main/roommember.py b/synapse/storage/data_stores/main/roommember.py
new file mode 100644
index 0000000000..137ebac833
--- /dev/null
+++ b/synapse/storage/data_stores/main/roommember.py
@@ -0,0 +1,1138 @@
+# -*- coding: utf-8 -*-
+# Copyright 2014-2016 OpenMarket Ltd
+# Copyright 2018 New Vector Ltd
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import logging
+from typing import Iterable, List, Set
+
+from six import iteritems, itervalues
+
+from canonicaljson import json
+
+from twisted.internet import defer
+
+from synapse.api.constants import EventTypes, Membership
+from synapse.metrics import LaterGauge
+from synapse.metrics.background_process_metrics import run_as_background_process
+from synapse.storage._base import (
+    LoggingTransaction,
+    SQLBaseStore,
+    make_in_list_sql_clause,
+)
+from synapse.storage.data_stores.main.events_worker import EventsWorkerStore
+from synapse.storage.database import Database
+from synapse.storage.engines import Sqlite3Engine
+from synapse.storage.roommember import (
+    GetRoomsForUserWithStreamOrdering,
+    MemberSummary,
+    ProfileInfo,
+    RoomsForUser,
+)
+from synapse.types import Collection, get_domain_from_id
+from synapse.util.async_helpers import Linearizer
+from synapse.util.caches import intern_string
+from synapse.util.caches.descriptors import cached, cachedInlineCallbacks, cachedList
+from synapse.util.metrics import Measure
+
+logger = logging.getLogger(__name__)
+
+
+_MEMBERSHIP_PROFILE_UPDATE_NAME = "room_membership_profile_update"
+_CURRENT_STATE_MEMBERSHIP_UPDATE_NAME = "current_state_events_membership"
+
+
+class RoomMemberWorkerStore(EventsWorkerStore):
+    def __init__(self, database: Database, db_conn, hs):
+        super(RoomMemberWorkerStore, self).__init__(database, db_conn, hs)
+
+        # Is the current_state_events.membership up to date? Or is the
+        # background update still running?
+        self._current_state_events_membership_up_to_date = False
+
+        txn = LoggingTransaction(
+            db_conn.cursor(),
+            name="_check_safe_current_state_events_membership_updated",
+            database_engine=self.database_engine,
+        )
+        self._check_safe_current_state_events_membership_updated_txn(txn)
+        txn.close()
+
+        if self.hs.config.metrics_flags.known_servers:
+            self._known_servers_count = 1
+            self.hs.get_clock().looping_call(
+                run_as_background_process,
+                60 * 1000,
+                "_count_known_servers",
+                self._count_known_servers,
+            )
+            self.hs.get_clock().call_later(
+                1000,
+                run_as_background_process,
+                "_count_known_servers",
+                self._count_known_servers,
+            )
+            LaterGauge(
+                "synapse_federation_known_servers",
+                "",
+                [],
+                lambda: self._known_servers_count,
+            )
+
+    @defer.inlineCallbacks
+    def _count_known_servers(self):
+        """
+        Count the servers that this server knows about.
+
+        The statistic is stored on the class for the
+        `synapse_federation_known_servers` LaterGauge to collect.
+        """
+
+        def _transact(txn):
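+            # SQLite has no split_part(), so extract the server name from each
+            # user ID with instr()/substr(); Postgres can split the state_key
+            # directly.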
+            if isinstance(self.database_engine, Sqlite3Engine):
+                query = """
+                    SELECT COUNT(DISTINCT substr(out.user_id, pos+1))
+                    FROM (
+                        SELECT rm.user_id as user_id, instr(rm.user_id, ':')
+                            AS pos FROM room_memberships as rm
+                        INNER JOIN current_state_events as c ON rm.event_id = c.event_id
+                        WHERE c.type = 'm.room.member'
+                    ) as out
+                """
+            else:
+                query = """
+                    SELECT COUNT(DISTINCT split_part(state_key, ':', 2))
+                    FROM current_state_events
+                    WHERE type = 'm.room.member' AND membership = 'join';
+                """
+            txn.execute(query)
+            return list(txn)[0][0]
+
+        count = yield self.db.runInteraction("get_known_servers", _transact)
+
+        # We always know about ourselves, even if we have nothing in
+        # room_memberships (for example, the server is new).
+        self._known_servers_count = max([count, 1])
+        return self._known_servers_count
+
+    def _check_safe_current_state_events_membership_updated_txn(self, txn):
+        """Checks if it is safe to assume the new current_state_events
+        membership column is up to date
+        """
+
+        pending_update = self.db.simple_select_one_txn(
+            txn,
+            table="background_updates",
+            keyvalues={"update_name": _CURRENT_STATE_MEMBERSHIP_UPDATE_NAME},
+            retcols=["update_name"],
+            allow_none=True,
+        )
+
+        self._current_state_events_membership_up_to_date = not pending_update
+
+        # If the update is still running, reschedule to run.
+        if pending_update:
+            self._clock.call_later(
+                15.0,
+                run_as_background_process,
+                "_check_safe_current_state_events_membership_updated",
+                self.db.runInteraction,
+                "_check_safe_current_state_events_membership_updated",
+                self._check_safe_current_state_events_membership_updated_txn,
+            )
+
+    @cached(max_entries=100000, iterable=True)
+    def get_users_in_room(self, room_id):
+        return self.db.runInteraction(
+            "get_users_in_room", self.get_users_in_room_txn, room_id
+        )
+
+    def get_users_in_room_txn(self, txn, room_id):
+        # If we can assume current_state_events.membership is up to date
+        # then we can avoid a join, which is a Very Good Thing given how
+        # frequently this function gets called.
+        if self._current_state_events_membership_up_to_date:
+            sql = """
+                SELECT state_key FROM current_state_events
+                WHERE type = 'm.room.member' AND room_id = ? AND membership = ?
+            """
+        else:
+            sql = """
+                SELECT state_key FROM room_memberships as m
+                INNER JOIN current_state_events as c
+                ON m.event_id = c.event_id
+                AND m.room_id = c.room_id
+                AND m.user_id = c.state_key
+                WHERE c.type = 'm.room.member' AND c.room_id = ? AND m.membership = ?
+            """
+
+        txn.execute(sql, (room_id, Membership.JOIN))
+        return [r[0] for r in txn]
+
+    @cached(max_entries=100000)
+    def get_room_summary(self, room_id):
+        """ Get the details of a room roughly suitable for use by the room
+        summary extension to /sync. Useful when lazy loading room members.
+        Args:
+            room_id (str): The room ID to query
+        Returns:
+            Deferred[dict[str, MemberSummary]]:
+                dict of membership states, pointing to a MemberSummary named tuple.
+        """
+
+        def _get_room_summary_txn(txn):
+            # first get counts.
+            # We do this all in one transaction to keep the cache small.
+            # FIXME: get rid of this when we have room_stats
+
+            # If we can assume current_state_events.membership is up to date
+            # then we can avoid a join, which is a Very Good Thing given how
+            # frequently this function gets called.
+            if self._current_state_events_membership_up_to_date:
+                # Note, rejected events will have a null membership field, so
+                # we manually filter them out.
+                sql = """
+                    SELECT count(*), membership FROM current_state_events
+                    WHERE type = 'm.room.member' AND room_id = ?
+                        AND membership IS NOT NULL
+                    GROUP BY membership
+                """
+            else:
+                sql = """
+                    SELECT count(*), m.membership FROM room_memberships as m
+                    INNER JOIN current_state_events as c
+                    ON m.event_id = c.event_id
+                    AND m.room_id = c.room_id
+                    AND m.user_id = c.state_key
+                    WHERE c.type = 'm.room.member' AND c.room_id = ?
+                    GROUP BY m.membership
+                """
+
+            txn.execute(sql, (room_id,))
+            res = {}
+            for count, membership in txn:
+                summary = res.setdefault(membership, MemberSummary([], count))
+
+            # we order by membership and then fairly arbitrarily by event_id so
+            # heroes are consistent
+            if self._current_state_events_membership_up_to_date:
+                # Note, rejected events will have a null membership field, so
+                # we manually filter them out.
+                sql = """
+                    SELECT state_key, membership, event_id
+                    FROM current_state_events
+                    WHERE type = 'm.room.member' AND room_id = ?
+                        AND membership IS NOT NULL
+                    ORDER BY
+                        CASE membership WHEN ? THEN 1 WHEN ? THEN 2 ELSE 3 END ASC,
+                        event_id ASC
+                    LIMIT ?
+                """
+            else:
+                sql = """
+                    SELECT c.state_key, m.membership, c.event_id
+                    FROM room_memberships as m
+                    INNER JOIN current_state_events as c USING (room_id, event_id)
+                    WHERE c.type = 'm.room.member' AND c.room_id = ?
+                    ORDER BY
+                        CASE m.membership WHEN ? THEN 1 WHEN ? THEN 2 ELSE 3 END ASC,
+                        c.event_id ASC
+                    LIMIT ?
+                """
+
+            # 6 is 5 (number of heroes) plus 1, in case one of them is the calling user.
+            txn.execute(sql, (room_id, Membership.JOIN, Membership.INVITE, 6))
+            for user_id, membership, event_id in txn:
+                summary = res[membership]
+                # we will always have a summary for this membership type at this
+                # point given the summary currently contains the counts.
+                members = summary.members
+                members.append((user_id, event_id))
+
+            return res
+
+        return self.db.runInteraction("get_room_summary", _get_room_summary_txn)
+
+    def _get_user_counts_in_room_txn(self, txn, room_id):
+        """
+        Get the user count in a room by membership.
+
+        Args:
+            txn (cursor)
+            room_id (str)
+
+        Returns:
+            dict[str, int]: Map from membership type to the number of users in the
+                room with that membership.
+        """
+        sql = """
+        SELECT m.membership, count(*) FROM room_memberships as m
+            INNER JOIN current_state_events as c USING(event_id)
+            WHERE c.type = 'm.room.member' AND c.room_id = ?
+            GROUP BY m.membership
+        """
+
+        txn.execute(sql, (room_id,))
+        return {row[0]: row[1] for row in txn}
+
+    @cached()
+    def get_invited_rooms_for_local_user(self, user_id):
+        """ Get all the rooms the *local* user is invited to
+
+        Args:
+            user_id (str): The user ID.
+        Returns:
+            A deferred list of RoomsForUser.
+        """
+
+        return self.get_rooms_for_local_user_where_membership_is(
+            user_id, [Membership.INVITE]
+        )
+
+    @defer.inlineCallbacks
+    def get_invite_for_local_user_in_room(self, user_id, room_id):
+        """Gets the invite for the given *local* user and room
+
+        Args:
+            user_id (str)
+            room_id (str)
+
+        Returns:
+            Deferred: Resolves to either a RoomsForUser or None if no invite was
+                found.
+        """
+        invites = yield self.get_invited_rooms_for_local_user(user_id)
+        for invite in invites:
+            if invite.room_id == room_id:
+                return invite
+        return None
+
+    @defer.inlineCallbacks
+    def get_rooms_for_local_user_where_membership_is(self, user_id, membership_list):
+        """ Get all the rooms for this *local* user where the membership for this user
+        matches one in the membership list.
+
+        Filters out forgotten rooms.
+
+        Args:
+            user_id (str): The user ID.
+            membership_list (list): A list of synapse.api.constants.Membership
+            values which the user must be in.
+
+        Returns:
+            Deferred[list[RoomsForUser]]
+        """
+        if not membership_list:
+            return defer.succeed(None)
+
+        rooms = yield self.db.runInteraction(
+            "get_rooms_for_local_user_where_membership_is",
+            self._get_rooms_for_local_user_where_membership_is_txn,
+            user_id,
+            membership_list,
+        )
+
+        # Now we filter out forgotten rooms
+        forgotten_rooms = yield self.get_forgotten_rooms_for_user(user_id)
+        return [room for room in rooms if room.room_id not in forgotten_rooms]
+
+    def _get_rooms_for_local_user_where_membership_is_txn(
+        self, txn, user_id, membership_list
+    ):
+        # Paranoia check.
+        if not self.hs.is_mine_id(user_id):
+            raise Exception(
+                "Cannot call 'get_rooms_for_local_user_where_membership_is' on non-local user %r"
+                % (user_id,),
+            )
+
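+        # make_in_list_sql_clause expands to something like "c.membership IN (?, ?)"
+        # together with the matching bind arguments for the membership list.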
+        clause, args = make_in_list_sql_clause(
+            self.database_engine, "c.membership", membership_list
+        )
+
+        sql = """
+            SELECT room_id, e.sender, c.membership, event_id, e.stream_ordering
+            FROM local_current_membership AS c
+            INNER JOIN events AS e USING (room_id, event_id)
+            WHERE
+                user_id = ?
+                AND %s
+        """ % (
+            clause,
+        )
+
+        txn.execute(sql, (user_id, *args))
+        results = [RoomsForUser(**r) for r in self.db.cursor_to_dict(txn)]
+
+        return results
+
+    @cached(max_entries=500000, iterable=True)
+    def get_rooms_for_user_with_stream_ordering(self, user_id):
+        """Returns a set of room_ids the user is currently joined to.
+
+        If a remote user only returns rooms this server is currently
+        participating in.
+
+        Args:
+            user_id (str)
+
+        Returns:
+            Deferred[frozenset[GetRoomsForUserWithStreamOrdering]]: Returns
+            the rooms the user is in currently, along with the stream ordering
+            of the most recent join for that user and room.
+        """
+        return self.db.runInteraction(
+            "get_rooms_for_user_with_stream_ordering",
+            self._get_rooms_for_user_with_stream_ordering_txn,
+            user_id,
+        )
+
+    def _get_rooms_for_user_with_stream_ordering_txn(self, txn, user_id):
+        # We use `current_state_events` here and not `local_current_membership`
+        # as a) this gets called with remote users and b) this only gets called
+        # for rooms the server is participating in.
+        if self._current_state_events_membership_up_to_date:
+            sql = """
+                SELECT room_id, e.stream_ordering
+                FROM current_state_events AS c
+                INNER JOIN events AS e USING (room_id, event_id)
+                WHERE
+                    c.type = 'm.room.member'
+                    AND state_key = ?
+                    AND c.membership = ?
+            """
+        else:
+            sql = """
+                SELECT room_id, e.stream_ordering
+                FROM current_state_events AS c
+                INNER JOIN room_memberships AS m USING (room_id, event_id)
+                INNER JOIN events AS e USING (room_id, event_id)
+                WHERE
+                    c.type = 'm.room.member'
+                    AND state_key = ?
+                    AND m.membership = ?
+            """
+
+        txn.execute(sql, (user_id, Membership.JOIN))
+        results = frozenset(GetRoomsForUserWithStreamOrdering(*row) for row in txn)
+
+        return results
+
+    async def get_users_server_still_shares_room_with(
+        self, user_ids: Collection[str]
+    ) -> Set[str]:
+        """Given a list of users return the set that the server still share a
+        room with.
+        """
+
+        if not user_ids:
+            return set()
+
+        def _get_users_server_still_shares_room_with_txn(txn):
+            sql = """
+                SELECT state_key FROM current_state_events
+                WHERE
+                    type = 'm.room.member'
+                    AND membership = 'join'
+                    AND %s
+                GROUP BY state_key
+            """
+
+            clause, args = make_in_list_sql_clause(
+                self.database_engine, "state_key", user_ids
+            )
+
+            txn.execute(sql % (clause,), args)
+
+            return {row[0] for row in txn}
+
+        return await self.db.runInteraction(
+            "get_users_server_still_shares_room_with",
+            _get_users_server_still_shares_room_with_txn,
+        )
+
+    @defer.inlineCallbacks
+    def get_rooms_for_user(self, user_id, on_invalidate=None):
+        """Returns a set of room_ids the user is currently joined to.
+
+        For a remote user, only returns rooms this server is currently
+        participating in.
+        """
+        rooms = yield self.get_rooms_for_user_with_stream_ordering(
+            user_id, on_invalidate=on_invalidate
+        )
+        return frozenset(r.room_id for r in rooms)
+
+    @cachedInlineCallbacks(max_entries=500000, cache_context=True, iterable=True)
+    def get_users_who_share_room_with_user(self, user_id, cache_context):
+        """Returns the set of users who share a room with `user_id`
+        """
+        room_ids = yield self.get_rooms_for_user(
+            user_id, on_invalidate=cache_context.invalidate
+        )
+
+        user_who_share_room = set()
+        for room_id in room_ids:
+            user_ids = yield self.get_users_in_room(
+                room_id, on_invalidate=cache_context.invalidate
+            )
+            user_who_share_room.update(user_ids)
+
+        return user_who_share_room
+
+    @defer.inlineCallbacks
+    def get_joined_users_from_context(self, event, context):
+        state_group = context.state_group
+        if not state_group:
+            # If state_group is None it means it has yet to be assigned a
+            # state group, i.e. we need to make sure that calls with a state_group
+            # of None don't hit previous cached calls with a None state_group.
+            # To do this we set the state_group to a new object as object() != object()
+            state_group = object()
+
+        current_state_ids = yield context.get_current_state_ids()
+        result = yield self._get_joined_users_from_context(
+            event.room_id, state_group, current_state_ids, event=event, context=context
+        )
+        return result
+
+    @defer.inlineCallbacks
+    def get_joined_users_from_state(self, room_id, state_entry):
+        state_group = state_entry.state_group
+        if not state_group:
+            # If state_group is None it means it has yet to be assigned a
+            # state group, i.e. we need to make sure that calls with a state_group
+            # of None don't hit previous cached calls with a None state_group.
+            # To do this we set the state_group to a new object as object() != object()
+            state_group = object()
+
+        with Measure(self._clock, "get_joined_users_from_state"):
+            return (
+                yield self._get_joined_users_from_context(
+                    room_id, state_group, state_entry.state, context=state_entry
+                )
+            )
+
+    @cachedInlineCallbacks(
+        num_args=2, cache_context=True, iterable=True, max_entries=100000
+    )
+    def _get_joined_users_from_context(
+        self,
+        room_id,
+        state_group,
+        current_state_ids,
+        cache_context,
+        event=None,
+        context=None,
+    ):
+        # We don't use `state_group`, it's there so that we can cache based
+        # on it. However, it's important that it's never None, since two current_states
+        # with a state_group of None are likely to be different.
+        # See bulk_get_push_rules_for_room for how we work around this.
+        assert state_group is not None
+
+        users_in_room = {}
+        member_event_ids = [
+            e_id
+            for key, e_id in iteritems(current_state_ids)
+            if key[0] == EventTypes.Member
+        ]
+
+        if context is not None:
+            # If we have a context with a delta from a previous state group,
+            # check if we also have the result from the previous group in cache.
+            # If we do then we can reuse that result and simply update it with
+            # any membership changes in `delta_ids`
+            if context.prev_group and context.delta_ids:
+                prev_res = self._get_joined_users_from_context.cache.get(
+                    (room_id, context.prev_group), None
+                )
+                if prev_res and isinstance(prev_res, dict):
+                    users_in_room = dict(prev_res)
+                    member_event_ids = [
+                        e_id
+                        for key, e_id in iteritems(context.delta_ids)
+                        if key[0] == EventTypes.Member
+                    ]
+                    for etype, state_key in context.delta_ids:
+                        if etype == EventTypes.Member:
+                            users_in_room.pop(state_key, None)
+
+        # We check if we have any of the member event ids in the event cache
+        # before we ask the DB
+
+        # We don't update the event cache hit ratio as it completely throws off
+        # the hit ratio counts. After all, we don't populate the cache if we
+        # miss it here
+        event_map = self._get_events_from_cache(
+            member_event_ids, allow_rejected=False, update_metrics=False
+        )
+
+        missing_member_event_ids = []
+        for event_id in member_event_ids:
+            ev_entry = event_map.get(event_id)
+            if ev_entry:
+                if ev_entry.event.membership == Membership.JOIN:
+                    users_in_room[ev_entry.event.state_key] = ProfileInfo(
+                        display_name=ev_entry.event.content.get("displayname", None),
+                        avatar_url=ev_entry.event.content.get("avatar_url", None),
+                    )
+            else:
+                missing_member_event_ids.append(event_id)
+
+        if missing_member_event_ids:
+            event_to_memberships = yield self._get_joined_profiles_from_event_ids(
+                missing_member_event_ids
+            )
+            users_in_room.update((row for row in event_to_memberships.values() if row))
+
+        if event is not None and event.type == EventTypes.Member:
+            if event.membership == Membership.JOIN:
+                if event.event_id in member_event_ids:
+                    users_in_room[event.state_key] = ProfileInfo(
+                        display_name=event.content.get("displayname", None),
+                        avatar_url=event.content.get("avatar_url", None),
+                    )
+
+        return users_in_room
+
+    @cached(max_entries=10000)
+    def _get_joined_profile_from_event_id(self, event_id):
+        raise NotImplementedError()
+
+    @cachedList(
+        cached_method_name="_get_joined_profile_from_event_id",
+        list_name="event_ids",
+        inlineCallbacks=True,
+    )
+    def _get_joined_profiles_from_event_ids(self, event_ids):
+        """For given set of member event_ids check if they point to a join
+        event and if so return the associated user and profile info.
+
+        Args:
+            event_ids (Iterable[str]): The member event IDs to lookup
+
+        Returns:
+            Deferred[dict[str, Tuple[str, ProfileInfo]|None]]: Map from event ID
+            to `user_id` and ProfileInfo (or None if not join event).
+        """
+
+        rows = yield self.db.simple_select_many_batch(
+            table="room_memberships",
+            column="event_id",
+            iterable=event_ids,
+            retcols=("user_id", "display_name", "avatar_url", "event_id"),
+            keyvalues={"membership": Membership.JOIN},
+            batch_size=500,
+            desc="_get_membership_from_event_ids",
+        )
+
+        return {
+            row["event_id"]: (
+                row["user_id"],
+                ProfileInfo(
+                    avatar_url=row["avatar_url"], display_name=row["display_name"]
+                ),
+            )
+            for row in rows
+        }
+
+    @cachedInlineCallbacks(max_entries=10000)
+    def is_host_joined(self, room_id, host):
+        if "%" in host or "_" in host:
+            raise Exception("Invalid host name")
+
+        sql = """
+            SELECT state_key FROM current_state_events AS c
+            INNER JOIN room_memberships AS m USING (event_id)
+            WHERE m.membership = 'join'
+                AND type = 'm.room.member'
+                AND c.room_id = ?
+                AND state_key LIKE ?
+            LIMIT 1
+        """
+
+        # We do need to be careful to ensure that host doesn't have any wild cards
+        # in it, but we checked above for known ones and we'll check below that
+        # the returned user actually has the correct domain.
+        like_clause = "%:" + host
+
+        rows = yield self.db.execute("is_host_joined", None, sql, room_id, like_clause)
+
+        if not rows:
+            return False
+
+        user_id = rows[0][0]
+        if get_domain_from_id(user_id) != host:
+            # This can only happen if the host name has something funky in it
+            raise Exception("Invalid host name")
+
+        return True
+
+    @cachedInlineCallbacks()
+    def was_host_joined(self, room_id, host):
+        """Check whether the server is or ever was in the room.
+
+        Args:
+            room_id (str)
+            host (str)
+
+        Returns:
+            Deferred: Resolves to True if the host is/was in the room, otherwise
+            False.
+        """
+        if "%" in host or "_" in host:
+            raise Exception("Invalid host name")
+
+        sql = """
+            SELECT user_id FROM room_memberships
+            WHERE room_id = ?
+                AND user_id LIKE ?
+                AND membership = 'join'
+            LIMIT 1
+        """
+
+        # We do need to be careful to ensure that host doesn't have any wild cards
+        # in it, but we checked above for known ones and we'll check below that
+        # the returned user actually has the correct domain.
+        like_clause = "%:" + host
+
+        rows = yield self.db.execute("was_host_joined", None, sql, room_id, like_clause)
+
+        if not rows:
+            return False
+
+        user_id = rows[0][0]
+        if get_domain_from_id(user_id) != host:
+            # This can only happen if the host name has something funky in it
+            raise Exception("Invalid host name")
+
+        return True
+
+    @defer.inlineCallbacks
+    def get_joined_hosts(self, room_id, state_entry):
+        state_group = state_entry.state_group
+        if not state_group:
+            # If state_group is None it means it has yet to be assigned a
+            # state group, i.e. we need to make sure that calls with a state_group
+            # of None don't hit previous cached calls with a None state_group.
+            # To do this we set the state_group to a new object as object() != object()
+            state_group = object()
+
+        with Measure(self._clock, "get_joined_hosts"):
+            return (
+                yield self._get_joined_hosts(
+                    room_id, state_group, state_entry.state, state_entry=state_entry
+                )
+            )
+
+    @cachedInlineCallbacks(num_args=2, max_entries=10000, iterable=True)
+    def _get_joined_hosts(self, room_id, state_group, current_state_ids, state_entry):
+        # We don't use `state_group`; it's there so that we can cache based
+        # on it. However, it's important that it's never None, since two
+        # current_states with a state_group of None are likely to be different.
+        # See bulk_get_push_rules_for_room for how we work around this.
+        assert state_group is not None
+
+        cache = yield self._get_joined_hosts_cache(room_id)
+        joined_hosts = yield cache.get_destinations(state_entry)
+
+        return joined_hosts
+
+    @cached(max_entries=10000)
+    def _get_joined_hosts_cache(self, room_id):
+        return _JoinedHostsCache(self, room_id)
+
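
The comment in get_joined_hosts explains why an unassigned state group is replaced with a fresh object(): two pieces of not-yet-persisted state must never share a cache entry, and distinct object() instances never compare equal. A minimal sketch of that sentinel-key trick, using functools.lru_cache instead of Synapse's @cached machinery (all names here are illustrative):

from functools import lru_cache

@lru_cache(maxsize=1000)
def _expensive_lookup(room_id, state_group_key):
    # Pretend this hits the database; the second argument exists only so the
    # cache can tell different state groups apart.
    return {"room_id": room_id, "key_id": id(state_group_key)}

def lookup(room_id, state_group):
    # state_group may be None for state that hasn't been assigned a group yet;
    # give each such call its own sentinel so it can never hit another call's entry.
    key = state_group if state_group is not None else object()
    return _expensive_lookup(room_id, key)

assert lookup("!room:a", 42) == lookup("!room:a", 42)      # same group: cached together
assert lookup("!room:a", None) != lookup("!room:a", None)  # ungrouped: never shared
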
+    @cachedInlineCallbacks(num_args=2)
+    def did_forget(self, user_id, room_id):
+        """Returns whether user_id has elected to discard history for room_id.
+
+        Returns False if they have since re-joined."""
+
+        def f(txn):
+            sql = """
+                SELECT COUNT(*) FROM room_memberships
+                WHERE user_id = ? AND room_id = ? AND forgotten = 0
+            """
+            txn.execute(sql, (user_id, room_id))
+            rows = txn.fetchall()
+            return rows[0][0]
+
+        count = yield self.db.runInteraction("did_forget_membership", f)
+        return count == 0
+
+    @cached()
+    def get_forgotten_rooms_for_user(self, user_id):
+        """Gets all rooms the user has forgotten.
+
+        Args:
+            user_id (str)
+
+        Returns:
+            Deferred[set[str]]
+        """
+
+        def _get_forgotten_rooms_for_user_txn(txn):
+            # This is a slightly convoluted query that first looks up all rooms
+            # that the user has forgotten in the past, then rechecks that list
+            # to see if any have subsequently been updated. This is done so that
+            # we can use a partial index on `forgotten = 1` on the assumption
+            # that few users will actually forget many rooms.
+            #
+            # Note that a room is considered "forgotten" if *all* membership
+            # events for that user and room have the forgotten field set (as
+            # when a user forgets a room we update all rows for that user and
+            # room, not just the current one).
+            sql = """
+                SELECT room_id, (
+                    SELECT count(*) FROM room_memberships
+                    WHERE room_id = m.room_id AND user_id = m.user_id AND forgotten = 0
+                ) AS count
+                FROM room_memberships AS m
+                WHERE user_id = ? AND forgotten = 1
+                GROUP BY room_id, user_id;
+            """
+            txn.execute(sql, (user_id,))
+            return {row[0] for row in txn if row[1] == 0}
+
+        return self.db.runInteraction(
+            "get_forgotten_rooms_for_user", _get_forgotten_rooms_for_user_txn
+        )
+
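
The correlated subquery in _get_forgotten_rooms_for_user_txn encodes the rule spelled out in its comment: a room only counts as forgotten when every membership row for that user has the forgotten flag set. A small self-contained sqlite3 check of that behaviour, with a trimmed-down hypothetical table:

import sqlite3

conn = sqlite3.connect(":memory:")
conn.execute("CREATE TABLE room_memberships (room_id TEXT, user_id TEXT, forgotten INTEGER)")
conn.executemany(
    "INSERT INTO room_memberships VALUES (?, ?, ?)",
    [
        ("!a:x", "@u:x", 1),  # forgotten once, but...
        ("!a:x", "@u:x", 0),  # ...re-joined since, so not forgotten overall
        ("!b:x", "@u:x", 1),  # every row forgotten: counts as forgotten
    ],
)
rows = conn.execute(
    """
    SELECT room_id, (
        SELECT count(*) FROM room_memberships
        WHERE room_id = m.room_id AND user_id = m.user_id AND forgotten = 0
    ) AS count
    FROM room_memberships AS m
    WHERE user_id = ? AND forgotten = 1
    GROUP BY room_id, user_id
    """,
    ("@u:x",),
).fetchall()
assert {r[0] for r in rows if r[1] == 0} == {"!b:x"}
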
+    @defer.inlineCallbacks
+    def get_rooms_user_has_been_in(self, user_id):
+        """Get all rooms that the user has ever been in.
+
+        Args:
+            user_id (str)
+
+        Returns:
+            Deferred[set[str]]: Set of room IDs.
+        """
+
+        room_ids = yield self.db.simple_select_onecol(
+            table="room_memberships",
+            keyvalues={"membership": Membership.JOIN, "user_id": user_id},
+            retcol="room_id",
+            desc="get_rooms_user_has_been_in",
+        )
+
+        return set(room_ids)
+
+    def get_membership_from_event_ids(
+        self, member_event_ids: Iterable[str]
+    ) -> List[dict]:
+        """Get user_id and membership of a set of event IDs.
+        """
+
+        return self.db.simple_select_many_batch(
+            table="room_memberships",
+            column="event_id",
+            iterable=member_event_ids,
+            retcols=("user_id", "membership", "event_id"),
+            keyvalues={},
+            batch_size=500,
+            desc="get_membership_from_event_ids",
+        )
+
+    async def is_local_host_in_room_ignoring_users(
+        self, room_id: str, ignore_users: Collection[str]
+    ) -> bool:
+        """Check if there are any local users, excluding those in the given
+        list, in the room.
+        """
+
+        clause, args = make_in_list_sql_clause(
+            self.database_engine, "user_id", ignore_users
+        )
+
+        sql = """
+            SELECT 1 FROM local_current_membership
+            WHERE
+                room_id = ? AND membership = ?
+                AND NOT (%s)
+                LIMIT 1
+        """ % (
+            clause,
+        )
+
+        def _is_local_host_in_room_ignoring_users_txn(txn):
+            txn.execute(sql, (room_id, Membership.JOIN, *args))
+
+            return bool(txn.fetchone())
+
+        return await self.db.runInteraction(
+            "is_local_host_in_room_ignoring_users",
+            _is_local_host_in_room_ignoring_users_txn,
+        )
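
The `NOT (%s)` fragment above is filled in by make_in_list_sql_clause, which expands a column name and an iterable of values into a parameterised membership test. A rough, SQLite-flavoured sketch of what such a clause builder looks like; this is illustrative only, not Synapse's actual helper (which also knows about Postgres array operators):

from typing import Iterable, List, Tuple

def in_list_clause(column: str, values: Iterable[str]) -> Tuple[str, List[str]]:
    args = list(values)
    if not args:
        # An empty IN () is invalid SQL; a clause that matches nothing keeps
        # "AND NOT (...)" true for every row instead.
        return "1 = 0", []
    placeholders = ", ".join("?" for _ in args)
    return "%s IN (%s)" % (column, placeholders), args

clause, args = in_list_clause("user_id", ["@a:x", "@b:x"])
sql = "SELECT 1 FROM local_current_membership WHERE room_id = ? AND NOT (%s) LIMIT 1" % clause
# sql == "... WHERE room_id = ? AND NOT (user_id IN (?, ?)) LIMIT 1"
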
+
+
+class RoomMemberBackgroundUpdateStore(SQLBaseStore):
+    def __init__(self, database: Database, db_conn, hs):
+        super(RoomMemberBackgroundUpdateStore, self).__init__(database, db_conn, hs)
+        self.db.updates.register_background_update_handler(
+            _MEMBERSHIP_PROFILE_UPDATE_NAME, self._background_add_membership_profile
+        )
+        self.db.updates.register_background_update_handler(
+            _CURRENT_STATE_MEMBERSHIP_UPDATE_NAME,
+            self._background_current_state_membership,
+        )
+        self.db.updates.register_background_index_update(
+            "room_membership_forgotten_idx",
+            index_name="room_memberships_user_room_forgotten",
+            table="room_memberships",
+            columns=["user_id", "room_id"],
+            where_clause="forgotten = 1",
+        )
+
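
The `room_membership_forgotten_idx` registration above asks the background updater to build a partial index over only the forgotten rows, which is the index the comment in get_forgotten_rooms_for_user relies on. Roughly the equivalent DDL, demonstrated against sqlite3 (hedged: the real update assembles the statement itself and, on Postgres, builds the index in the background rather than up front):

import sqlite3

conn = sqlite3.connect(":memory:")
conn.execute("CREATE TABLE room_memberships (user_id TEXT, room_id TEXT, forgotten INTEGER)")
conn.execute(
    """
    CREATE INDEX room_memberships_user_room_forgotten
        ON room_memberships (user_id, room_id)
        WHERE forgotten = 1
    """
)
# Only rows with forgotten = 1 are indexed, on the assumption that few users
# forget many rooms, which keeps the "user_id = ? AND forgotten = 1" lookup cheap.
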
+    @defer.inlineCallbacks
+    def _background_add_membership_profile(self, progress, batch_size):
+        target_min_stream_id = progress.get(
+            "target_min_stream_id_inclusive", self._min_stream_order_on_start
+        )
+        max_stream_id = progress.get(
+            "max_stream_id_exclusive", self._stream_order_on_start + 1
+        )
+
+        INSERT_CLUMP_SIZE = 1000
+
+        def add_membership_profile_txn(txn):
+            sql = """
+                SELECT stream_ordering, event_id, events.room_id, event_json.json
+                FROM events
+                INNER JOIN event_json USING (event_id)
+                INNER JOIN room_memberships USING (event_id)
+                WHERE ? <= stream_ordering AND stream_ordering < ?
+                AND type = 'm.room.member'
+                ORDER BY stream_ordering DESC
+                LIMIT ?
+            """
+
+            txn.execute(sql, (target_min_stream_id, max_stream_id, batch_size))
+
+            rows = self.db.cursor_to_dict(txn)
+            if not rows:
+                return 0
+
+            min_stream_id = rows[-1]["stream_ordering"]
+
+            to_update = []
+            for row in rows:
+                event_id = row["event_id"]
+                room_id = row["room_id"]
+                try:
+                    event_json = json.loads(row["json"])
+                    content = event_json["content"]
+                except Exception:
+                    continue
+
+                display_name = content.get("displayname", None)
+                avatar_url = content.get("avatar_url", None)
+
+                if display_name or avatar_url:
+                    to_update.append((display_name, avatar_url, event_id, room_id))
+
+            to_update_sql = """
+                UPDATE room_memberships SET display_name = ?, avatar_url = ?
+                WHERE event_id = ? AND room_id = ?
+            """
+            for index in range(0, len(to_update), INSERT_CLUMP_SIZE):
+                clump = to_update[index : index + INSERT_CLUMP_SIZE]
+                txn.executemany(to_update_sql, clump)
+
+            progress = {
+                "target_min_stream_id_inclusive": target_min_stream_id,
+                "max_stream_id_exclusive": min_stream_id,
+            }
+
+            self.db.updates._background_update_progress_txn(
+                txn, _MEMBERSHIP_PROFILE_UPDATE_NAME, progress
+            )
+
+            return len(rows)
+
+        result = yield self.db.runInteraction(
+            _MEMBERSHIP_PROFILE_UPDATE_NAME, add_membership_profile_txn
+        )
+
+        if not result:
+            yield self.db.updates._end_background_update(
+                _MEMBERSHIP_PROFILE_UPDATE_NAME
+            )
+
+        return result
+
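
_background_add_membership_profile walks the events table backwards: each batch reads the top rows below max_stream_id_exclusive and then records the smallest stream ordering it saw as the new exclusive ceiling, so successive runs converge on target_min_stream_id_inclusive. A condensed sketch of that shrinking-window loop, with a stand-in fetch function in place of the real SQL:

def walk_backwards(fetch_batch, target_min, max_exclusive, batch_size):
    """fetch_batch(lo, hi, limit) -> descending list of stream orderings in [lo, hi)."""
    while True:
        rows = fetch_batch(target_min, max_exclusive, batch_size)
        if not rows:
            return                    # caught up: the background update can end
        # ... process this batch of rows ...
        max_exclusive = rows[-1]      # smallest ordering seen becomes the new ceiling
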
+    @defer.inlineCallbacks
+    def _background_current_state_membership(self, progress, batch_size):
+        """Update the new membership column on current_state_events.
+
+        This works by iterating over all rooms in alphabetical order.
+        """
+
+        def _background_current_state_membership_txn(txn, last_processed_room):
+            processed = 0
+            while processed < batch_size:
+                txn.execute(
+                    """
+                        SELECT MIN(room_id) FROM current_state_events WHERE room_id > ?
+                    """,
+                    (last_processed_room,),
+                )
+                row = txn.fetchone()
+                if not row or not row[0]:
+                    return processed, True
+
+                (next_room,) = row
+
+                sql = """
+                    UPDATE current_state_events
+                    SET membership = (
+                        SELECT membership FROM room_memberships
+                        WHERE event_id = current_state_events.event_id
+                    )
+                    WHERE room_id = ?
+                """
+                txn.execute(sql, (next_room,))
+                processed += txn.rowcount
+
+                last_processed_room = next_room
+
+            self.db.updates._background_update_progress_txn(
+                txn,
+                _CURRENT_STATE_MEMBERSHIP_UPDATE_NAME,
+                {"last_processed_room": last_processed_room},
+            )
+
+            return processed, False
+
+        # If we haven't got a last processed room then just use the empty
+        # string, which will compare before all room IDs correctly.
+        last_processed_room = progress.get("last_processed_room", "")
+
+        row_count, finished = yield self.db.runInteraction(
+            "_background_current_state_membership_update",
+            _background_current_state_membership_txn,
+            last_processed_room,
+        )
+
+        if finished:
+            yield self.db.updates._end_background_update(
+                _CURRENT_STATE_MEMBERSHIP_UPDATE_NAME
+            )
+
+        return row_count
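
_background_current_state_membership processes one room per iteration and checkpoints the last room ID it finished, so a restarted update resumes where the previous batch stopped. A minimal sketch of that resume-from-checkpoint loop, with hypothetical callables standing in for the real transaction helpers:

def run_batch(next_room_update, batch_size, last_processed_room, save_progress):
    """next_room_update(room_id) -> (next_room, rows_updated), or None if no rooms remain."""
    processed = 0
    while processed < batch_size:
        result = next_room_update(last_processed_room)
        if result is None:
            return processed, True       # nothing left; the caller ends the update
        next_room, rows_updated = result
        processed += rows_updated
        last_processed_room = next_room
    save_progress(last_processed_room)   # checkpoint so the next batch resumes here
    return processed, False
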
+
+
+class RoomMemberStore(RoomMemberWorkerStore, RoomMemberBackgroundUpdateStore):
+    def __init__(self, database: Database, db_conn, hs):
+        super(RoomMemberStore, self).__init__(database, db_conn, hs)
+
+    def forget(self, user_id, room_id):
+        """Indicate that user_id wishes to discard history for room_id."""
+
+        def f(txn):
+            sql = """
+                UPDATE room_memberships
+                SET forgotten = 1
+                WHERE user_id = ? AND room_id = ?
+            """
+            txn.execute(sql, (user_id, room_id))
+
+            self._invalidate_cache_and_stream(txn, self.did_forget, (user_id, room_id))
+            self._invalidate_cache_and_stream(
+                txn, self.get_forgotten_rooms_for_user, (user_id,)
+            )
+
+        return self.db.runInteraction("forget_membership", f)
+
+
+class _JoinedHostsCache(object):
+    """Cache for joined hosts in a room that is optimised to handle updates
+    via state deltas.
+    """
+
+    def __init__(self, store, room_id):
+        self.store = store
+        self.room_id = room_id
+
+        self.hosts_to_joined_users = {}
+
+        self.state_group = object()
+
+        self.linearizer = Linearizer("_JoinedHostsCache")
+
+        self._len = 0
+
+    @defer.inlineCallbacks
+    def get_destinations(self, state_entry):
+        """Get set of destinations for a state entry
+
+        Args:
+            state_entry(synapse.state._StateCacheEntry)
+        """
+        if state_entry.state_group == self.state_group:
+            return frozenset(self.hosts_to_joined_users)
+
+        with (yield self.linearizer.queue(())):
+            if state_entry.state_group == self.state_group:
+                pass
+            elif state_entry.prev_group == self.state_group:
+                for (typ, state_key), event_id in iteritems(state_entry.delta_ids):
+                    if typ != EventTypes.Member:
+                        continue
+
+                    host = intern_string(get_domain_from_id(state_key))
+                    user_id = state_key
+                    known_joins = self.hosts_to_joined_users.setdefault(host, set())
+
+                    event = yield self.store.get_event(event_id)
+                    if event.membership == Membership.JOIN:
+                        known_joins.add(user_id)
+                    else:
+                        known_joins.discard(user_id)
+
+                        if not known_joins:
+                            self.hosts_to_joined_users.pop(host, None)
+            else:
+                joined_users = yield self.store.get_joined_users_from_state(
+                    self.room_id, state_entry
+                )
+
+                self.hosts_to_joined_users = {}
+                for user_id in joined_users:
+                    host = intern_string(get_domain_from_id(user_id))
+                    self.hosts_to_joined_users.setdefault(host, set()).add(user_id)
+
+            if state_entry.state_group:
+                self.state_group = state_entry.state_group
+            else:
+                self.state_group = object()
+            self._len = sum(len(v) for v in itervalues(self.hosts_to_joined_users))
+        return frozenset(self.hosts_to_joined_users)
+
+    def __len__(self):
+        return self._len
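
_JoinedHostsCache.get_destinations distinguishes three cases: the state group is unchanged (reuse the cached hosts), the new state group builds directly on the cached one (apply only the member deltas), or anything else (rebuild from scratch). A condensed, synchronous sketch of that branching using plain dicts; the names and helper callables here are hypothetical, not the store's API:

def update_hosts(cache, entry, get_membership, get_all_joined_users):
    if entry.state_group == cache["state_group"]:
        return frozenset(cache["hosts"])                 # nothing changed

    if entry.prev_group == cache["state_group"]:
        # Incremental path: only membership deltas can change the host set.
        for (etype, user_id), event_id in entry.delta_ids.items():
            if etype != "m.room.member":
                continue
            host = user_id.split(":", 1)[1]
            joined = cache["hosts"].setdefault(host, set())
            if get_membership(event_id) == "join":
                joined.add(user_id)
            else:
                joined.discard(user_id)
                if not joined:
                    cache["hosts"].pop(host, None)
    else:
        # Unrelated state group: recompute the whole mapping.
        cache["hosts"] = {}
        for user_id in get_all_joined_users(entry):
            host = user_id.split(":", 1)[1]
            cache["hosts"].setdefault(host, set()).add(user_id)

    # As above, an unassigned state group gets a fresh sentinel object.
    cache["state_group"] = entry.state_group if entry.state_group else object()
    return frozenset(cache["hosts"])
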
diff --git a/synapse/storage/schema/delta/12/v12.sql b/synapse/storage/data_stores/main/schema/delta/12/v12.sql
index 5964c5aaac..5964c5aaac 100644
--- a/synapse/storage/schema/delta/12/v12.sql
+++ b/synapse/storage/data_stores/main/schema/delta/12/v12.sql
diff --git a/synapse/storage/schema/delta/13/v13.sql b/synapse/storage/data_stores/main/schema/delta/13/v13.sql
index f8649e5d99..f8649e5d99 100644
--- a/synapse/storage/schema/delta/13/v13.sql
+++ b/synapse/storage/data_stores/main/schema/delta/13/v13.sql
diff --git a/synapse/storage/schema/delta/14/v14.sql b/synapse/storage/data_stores/main/schema/delta/14/v14.sql
index a831920da6..a831920da6 100644
--- a/synapse/storage/schema/delta/14/v14.sql
+++ b/synapse/storage/data_stores/main/schema/delta/14/v14.sql
diff --git a/synapse/storage/schema/delta/15/appservice_txns.sql b/synapse/storage/data_stores/main/schema/delta/15/appservice_txns.sql
index e4f5e76aec..e4f5e76aec 100644
--- a/synapse/storage/schema/delta/15/appservice_txns.sql
+++ b/synapse/storage/data_stores/main/schema/delta/15/appservice_txns.sql
diff --git a/synapse/storage/schema/delta/15/presence_indices.sql b/synapse/storage/data_stores/main/schema/delta/15/presence_indices.sql
index 6b8d0f1ca7..6b8d0f1ca7 100644
--- a/synapse/storage/schema/delta/15/presence_indices.sql
+++ b/synapse/storage/data_stores/main/schema/delta/15/presence_indices.sql
diff --git a/synapse/storage/schema/delta/15/v15.sql b/synapse/storage/data_stores/main/schema/delta/15/v15.sql
index 9523d2bcc3..9523d2bcc3 100644
--- a/synapse/storage/schema/delta/15/v15.sql
+++ b/synapse/storage/data_stores/main/schema/delta/15/v15.sql
diff --git a/synapse/storage/schema/delta/16/events_order_index.sql b/synapse/storage/data_stores/main/schema/delta/16/events_order_index.sql
index a48f215170..a48f215170 100644
--- a/synapse/storage/schema/delta/16/events_order_index.sql
+++ b/synapse/storage/data_stores/main/schema/delta/16/events_order_index.sql
diff --git a/synapse/storage/schema/delta/16/remote_media_cache_index.sql b/synapse/storage/data_stores/main/schema/delta/16/remote_media_cache_index.sql
index 7a15265cb1..7a15265cb1 100644
--- a/synapse/storage/schema/delta/16/remote_media_cache_index.sql
+++ b/synapse/storage/data_stores/main/schema/delta/16/remote_media_cache_index.sql
diff --git a/synapse/storage/schema/delta/16/remove_duplicates.sql b/synapse/storage/data_stores/main/schema/delta/16/remove_duplicates.sql
index 65c97b5e2f..65c97b5e2f 100644
--- a/synapse/storage/schema/delta/16/remove_duplicates.sql
+++ b/synapse/storage/data_stores/main/schema/delta/16/remove_duplicates.sql
diff --git a/synapse/storage/schema/delta/16/room_alias_index.sql b/synapse/storage/data_stores/main/schema/delta/16/room_alias_index.sql
index f82486132b..f82486132b 100644
--- a/synapse/storage/schema/delta/16/room_alias_index.sql
+++ b/synapse/storage/data_stores/main/schema/delta/16/room_alias_index.sql
diff --git a/synapse/storage/schema/delta/16/unique_constraints.sql b/synapse/storage/data_stores/main/schema/delta/16/unique_constraints.sql
index 5b8de52c33..5b8de52c33 100644
--- a/synapse/storage/schema/delta/16/unique_constraints.sql
+++ b/synapse/storage/data_stores/main/schema/delta/16/unique_constraints.sql
diff --git a/synapse/storage/schema/delta/16/users.sql b/synapse/storage/data_stores/main/schema/delta/16/users.sql
index cd0709250d..cd0709250d 100644
--- a/synapse/storage/schema/delta/16/users.sql
+++ b/synapse/storage/data_stores/main/schema/delta/16/users.sql
diff --git a/synapse/storage/schema/delta/17/drop_indexes.sql b/synapse/storage/data_stores/main/schema/delta/17/drop_indexes.sql
index 7c9a90e27f..7c9a90e27f 100644
--- a/synapse/storage/schema/delta/17/drop_indexes.sql
+++ b/synapse/storage/data_stores/main/schema/delta/17/drop_indexes.sql
diff --git a/synapse/storage/schema/delta/17/server_keys.sql b/synapse/storage/data_stores/main/schema/delta/17/server_keys.sql
index 70b247a06b..70b247a06b 100644
--- a/synapse/storage/schema/delta/17/server_keys.sql
+++ b/synapse/storage/data_stores/main/schema/delta/17/server_keys.sql
diff --git a/synapse/storage/schema/delta/17/user_threepids.sql b/synapse/storage/data_stores/main/schema/delta/17/user_threepids.sql
index c17715ac80..c17715ac80 100644
--- a/synapse/storage/schema/delta/17/user_threepids.sql
+++ b/synapse/storage/data_stores/main/schema/delta/17/user_threepids.sql
diff --git a/synapse/storage/schema/delta/18/server_keys_bigger_ints.sql b/synapse/storage/data_stores/main/schema/delta/18/server_keys_bigger_ints.sql
index 6e0871c92b..6e0871c92b 100644
--- a/synapse/storage/schema/delta/18/server_keys_bigger_ints.sql
+++ b/synapse/storage/data_stores/main/schema/delta/18/server_keys_bigger_ints.sql
diff --git a/synapse/storage/schema/delta/19/event_index.sql b/synapse/storage/data_stores/main/schema/delta/19/event_index.sql
index 18b97b4332..18b97b4332 100644
--- a/synapse/storage/schema/delta/19/event_index.sql
+++ b/synapse/storage/data_stores/main/schema/delta/19/event_index.sql
diff --git a/synapse/storage/schema/delta/20/dummy.sql b/synapse/storage/data_stores/main/schema/delta/20/dummy.sql
index e0ac49d1ec..e0ac49d1ec 100644
--- a/synapse/storage/schema/delta/20/dummy.sql
+++ b/synapse/storage/data_stores/main/schema/delta/20/dummy.sql
diff --git a/synapse/storage/schema/delta/20/pushers.py b/synapse/storage/data_stores/main/schema/delta/20/pushers.py
index 3edfcfd783..3edfcfd783 100644
--- a/synapse/storage/schema/delta/20/pushers.py
+++ b/synapse/storage/data_stores/main/schema/delta/20/pushers.py
diff --git a/synapse/storage/schema/delta/21/end_to_end_keys.sql b/synapse/storage/data_stores/main/schema/delta/21/end_to_end_keys.sql
index 4c2fb20b77..4c2fb20b77 100644
--- a/synapse/storage/schema/delta/21/end_to_end_keys.sql
+++ b/synapse/storage/data_stores/main/schema/delta/21/end_to_end_keys.sql
diff --git a/synapse/storage/schema/delta/21/receipts.sql b/synapse/storage/data_stores/main/schema/delta/21/receipts.sql
index d070845477..d070845477 100644
--- a/synapse/storage/schema/delta/21/receipts.sql
+++ b/synapse/storage/data_stores/main/schema/delta/21/receipts.sql
diff --git a/synapse/storage/schema/delta/22/receipts_index.sql b/synapse/storage/data_stores/main/schema/delta/22/receipts_index.sql
index bfc0b3bcaa..bfc0b3bcaa 100644
--- a/synapse/storage/schema/delta/22/receipts_index.sql
+++ b/synapse/storage/data_stores/main/schema/delta/22/receipts_index.sql
diff --git a/synapse/storage/schema/delta/22/user_threepids_unique.sql b/synapse/storage/data_stores/main/schema/delta/22/user_threepids_unique.sql
index 87edfa454c..87edfa454c 100644
--- a/synapse/storage/schema/delta/22/user_threepids_unique.sql
+++ b/synapse/storage/data_stores/main/schema/delta/22/user_threepids_unique.sql
diff --git a/synapse/storage/schema/delta/24/stats_reporting.sql b/synapse/storage/data_stores/main/schema/delta/24/stats_reporting.sql
index acea7483bd..acea7483bd 100644
--- a/synapse/storage/schema/delta/24/stats_reporting.sql
+++ b/synapse/storage/data_stores/main/schema/delta/24/stats_reporting.sql
diff --git a/synapse/storage/schema/delta/25/fts.py b/synapse/storage/data_stores/main/schema/delta/25/fts.py
index 4b2ffd35fd..4b2ffd35fd 100644
--- a/synapse/storage/schema/delta/25/fts.py
+++ b/synapse/storage/data_stores/main/schema/delta/25/fts.py
diff --git a/synapse/storage/schema/delta/25/guest_access.sql b/synapse/storage/data_stores/main/schema/delta/25/guest_access.sql
index 1ea389b471..1ea389b471 100644
--- a/synapse/storage/schema/delta/25/guest_access.sql
+++ b/synapse/storage/data_stores/main/schema/delta/25/guest_access.sql
diff --git a/synapse/storage/schema/delta/25/history_visibility.sql b/synapse/storage/data_stores/main/schema/delta/25/history_visibility.sql
index f468fc1897..f468fc1897 100644
--- a/synapse/storage/schema/delta/25/history_visibility.sql
+++ b/synapse/storage/data_stores/main/schema/delta/25/history_visibility.sql
diff --git a/synapse/storage/schema/delta/25/tags.sql b/synapse/storage/data_stores/main/schema/delta/25/tags.sql
index 7a32ce68e4..7a32ce68e4 100644
--- a/synapse/storage/schema/delta/25/tags.sql
+++ b/synapse/storage/data_stores/main/schema/delta/25/tags.sql
diff --git a/synapse/storage/schema/delta/26/account_data.sql b/synapse/storage/data_stores/main/schema/delta/26/account_data.sql
index e395de2b5e..e395de2b5e 100644
--- a/synapse/storage/schema/delta/26/account_data.sql
+++ b/synapse/storage/data_stores/main/schema/delta/26/account_data.sql
diff --git a/synapse/storage/schema/delta/27/account_data.sql b/synapse/storage/data_stores/main/schema/delta/27/account_data.sql
index bf0558b5b3..bf0558b5b3 100644
--- a/synapse/storage/schema/delta/27/account_data.sql
+++ b/synapse/storage/data_stores/main/schema/delta/27/account_data.sql
diff --git a/synapse/storage/schema/delta/27/forgotten_memberships.sql b/synapse/storage/data_stores/main/schema/delta/27/forgotten_memberships.sql
index e2094f37fe..e2094f37fe 100644
--- a/synapse/storage/schema/delta/27/forgotten_memberships.sql
+++ b/synapse/storage/data_stores/main/schema/delta/27/forgotten_memberships.sql
diff --git a/synapse/storage/schema/delta/27/ts.py b/synapse/storage/data_stores/main/schema/delta/27/ts.py
index 414f9f5aa0..414f9f5aa0 100644
--- a/synapse/storage/schema/delta/27/ts.py
+++ b/synapse/storage/data_stores/main/schema/delta/27/ts.py
diff --git a/synapse/storage/schema/delta/28/event_push_actions.sql b/synapse/storage/data_stores/main/schema/delta/28/event_push_actions.sql
index 4d519849df..4d519849df 100644
--- a/synapse/storage/schema/delta/28/event_push_actions.sql
+++ b/synapse/storage/data_stores/main/schema/delta/28/event_push_actions.sql
diff --git a/synapse/storage/schema/delta/28/events_room_stream.sql b/synapse/storage/data_stores/main/schema/delta/28/events_room_stream.sql
index 36609475f1..36609475f1 100644
--- a/synapse/storage/schema/delta/28/events_room_stream.sql
+++ b/synapse/storage/data_stores/main/schema/delta/28/events_room_stream.sql
diff --git a/synapse/storage/schema/delta/28/public_roms_index.sql b/synapse/storage/data_stores/main/schema/delta/28/public_roms_index.sql
index 6c1fd68c5b..6c1fd68c5b 100644
--- a/synapse/storage/schema/delta/28/public_roms_index.sql
+++ b/synapse/storage/data_stores/main/schema/delta/28/public_roms_index.sql
diff --git a/synapse/storage/schema/delta/28/receipts_user_id_index.sql b/synapse/storage/data_stores/main/schema/delta/28/receipts_user_id_index.sql
index cb84c69baa..cb84c69baa 100644
--- a/synapse/storage/schema/delta/28/receipts_user_id_index.sql
+++ b/synapse/storage/data_stores/main/schema/delta/28/receipts_user_id_index.sql
diff --git a/synapse/storage/schema/delta/28/upgrade_times.sql b/synapse/storage/data_stores/main/schema/delta/28/upgrade_times.sql
index 3e4a9ab455..3e4a9ab455 100644
--- a/synapse/storage/schema/delta/28/upgrade_times.sql
+++ b/synapse/storage/data_stores/main/schema/delta/28/upgrade_times.sql
diff --git a/synapse/storage/schema/delta/28/users_is_guest.sql b/synapse/storage/data_stores/main/schema/delta/28/users_is_guest.sql
index 21d2b420bf..21d2b420bf 100644
--- a/synapse/storage/schema/delta/28/users_is_guest.sql
+++ b/synapse/storage/data_stores/main/schema/delta/28/users_is_guest.sql
diff --git a/synapse/storage/schema/delta/29/push_actions.sql b/synapse/storage/data_stores/main/schema/delta/29/push_actions.sql
index 84b21cf813..84b21cf813 100644
--- a/synapse/storage/schema/delta/29/push_actions.sql
+++ b/synapse/storage/data_stores/main/schema/delta/29/push_actions.sql
diff --git a/synapse/storage/schema/delta/30/alias_creator.sql b/synapse/storage/data_stores/main/schema/delta/30/alias_creator.sql
index c9d0dde638..c9d0dde638 100644
--- a/synapse/storage/schema/delta/30/alias_creator.sql
+++ b/synapse/storage/data_stores/main/schema/delta/30/alias_creator.sql
diff --git a/synapse/storage/schema/delta/30/as_users.py b/synapse/storage/data_stores/main/schema/delta/30/as_users.py
index 9b95411fb6..9b95411fb6 100644
--- a/synapse/storage/schema/delta/30/as_users.py
+++ b/synapse/storage/data_stores/main/schema/delta/30/as_users.py
diff --git a/synapse/storage/schema/delta/30/deleted_pushers.sql b/synapse/storage/data_stores/main/schema/delta/30/deleted_pushers.sql
index 712c454aa1..712c454aa1 100644
--- a/synapse/storage/schema/delta/30/deleted_pushers.sql
+++ b/synapse/storage/data_stores/main/schema/delta/30/deleted_pushers.sql
diff --git a/synapse/storage/schema/delta/30/presence_stream.sql b/synapse/storage/data_stores/main/schema/delta/30/presence_stream.sql
index 606bbb037d..606bbb037d 100644
--- a/synapse/storage/schema/delta/30/presence_stream.sql
+++ b/synapse/storage/data_stores/main/schema/delta/30/presence_stream.sql
diff --git a/synapse/storage/schema/delta/30/public_rooms.sql b/synapse/storage/data_stores/main/schema/delta/30/public_rooms.sql
index f09db4faa6..f09db4faa6 100644
--- a/synapse/storage/schema/delta/30/public_rooms.sql
+++ b/synapse/storage/data_stores/main/schema/delta/30/public_rooms.sql
diff --git a/synapse/storage/schema/delta/30/push_rule_stream.sql b/synapse/storage/data_stores/main/schema/delta/30/push_rule_stream.sql
index 735aa8d5f6..735aa8d5f6 100644
--- a/synapse/storage/schema/delta/30/push_rule_stream.sql
+++ b/synapse/storage/data_stores/main/schema/delta/30/push_rule_stream.sql
diff --git a/synapse/storage/schema/delta/30/threepid_guest_access_tokens.sql b/synapse/storage/data_stores/main/schema/delta/30/threepid_guest_access_tokens.sql
index 0dd2f1360c..0dd2f1360c 100644
--- a/synapse/storage/schema/delta/30/threepid_guest_access_tokens.sql
+++ b/synapse/storage/data_stores/main/schema/delta/30/threepid_guest_access_tokens.sql
diff --git a/synapse/storage/schema/delta/31/invites.sql b/synapse/storage/data_stores/main/schema/delta/31/invites.sql
index 2c57846d5a..2c57846d5a 100644
--- a/synapse/storage/schema/delta/31/invites.sql
+++ b/synapse/storage/data_stores/main/schema/delta/31/invites.sql
diff --git a/synapse/storage/schema/delta/31/local_media_repository_url_cache.sql b/synapse/storage/data_stores/main/schema/delta/31/local_media_repository_url_cache.sql
index 9efb4280eb..9efb4280eb 100644
--- a/synapse/storage/schema/delta/31/local_media_repository_url_cache.sql
+++ b/synapse/storage/data_stores/main/schema/delta/31/local_media_repository_url_cache.sql
diff --git a/synapse/storage/schema/delta/31/pushers.py b/synapse/storage/data_stores/main/schema/delta/31/pushers.py
index 9bb504aad5..9bb504aad5 100644
--- a/synapse/storage/schema/delta/31/pushers.py
+++ b/synapse/storage/data_stores/main/schema/delta/31/pushers.py
diff --git a/synapse/storage/schema/delta/31/pushers_index.sql b/synapse/storage/data_stores/main/schema/delta/31/pushers_index.sql
index a82add88fd..a82add88fd 100644
--- a/synapse/storage/schema/delta/31/pushers_index.sql
+++ b/synapse/storage/data_stores/main/schema/delta/31/pushers_index.sql
diff --git a/synapse/storage/schema/delta/31/search_update.py b/synapse/storage/data_stores/main/schema/delta/31/search_update.py
index 7d8ca5f93f..7d8ca5f93f 100644
--- a/synapse/storage/schema/delta/31/search_update.py
+++ b/synapse/storage/data_stores/main/schema/delta/31/search_update.py
diff --git a/synapse/storage/schema/delta/32/events.sql b/synapse/storage/data_stores/main/schema/delta/32/events.sql
index 1dd0f9e170..1dd0f9e170 100644
--- a/synapse/storage/schema/delta/32/events.sql
+++ b/synapse/storage/data_stores/main/schema/delta/32/events.sql
diff --git a/synapse/storage/schema/delta/32/openid.sql b/synapse/storage/data_stores/main/schema/delta/32/openid.sql
index 36f37b11c8..36f37b11c8 100644
--- a/synapse/storage/schema/delta/32/openid.sql
+++ b/synapse/storage/data_stores/main/schema/delta/32/openid.sql
diff --git a/synapse/storage/schema/delta/32/pusher_throttle.sql b/synapse/storage/data_stores/main/schema/delta/32/pusher_throttle.sql
index d86d30c13c..d86d30c13c 100644
--- a/synapse/storage/schema/delta/32/pusher_throttle.sql
+++ b/synapse/storage/data_stores/main/schema/delta/32/pusher_throttle.sql
diff --git a/synapse/storage/schema/delta/32/remove_indices.sql b/synapse/storage/data_stores/main/schema/delta/32/remove_indices.sql
index 4219cdd06a..2de50d408c 100644
--- a/synapse/storage/schema/delta/32/remove_indices.sql
+++ b/synapse/storage/data_stores/main/schema/delta/32/remove_indices.sql
@@ -20,7 +20,6 @@ DROP INDEX IF EXISTS events_room_id; -- Prefix of events_room_stream
 DROP INDEX IF EXISTS events_order; -- Prefix of events_order_topo_stream_room
 DROP INDEX IF EXISTS events_topological_ordering; -- Prefix of events_order_topo_stream_room
 DROP INDEX IF EXISTS events_stream_ordering; -- Duplicate of PRIMARY KEY
-DROP INDEX IF EXISTS state_groups_id; -- Duplicate of PRIMARY KEY
 DROP INDEX IF EXISTS event_to_state_groups_id; -- Duplicate of PRIMARY KEY
 DROP INDEX IF EXISTS event_push_actions_room_id_event_id_user_id_profile_tag; -- Duplicate of UNIQUE CONSTRAINT
 
diff --git a/synapse/storage/schema/delta/32/reports.sql b/synapse/storage/data_stores/main/schema/delta/32/reports.sql
index d13609776f..d13609776f 100644
--- a/synapse/storage/schema/delta/32/reports.sql
+++ b/synapse/storage/data_stores/main/schema/delta/32/reports.sql
diff --git a/synapse/storage/schema/delta/33/access_tokens_device_index.sql b/synapse/storage/data_stores/main/schema/delta/33/access_tokens_device_index.sql
index 61ad3fe3e8..61ad3fe3e8 100644
--- a/synapse/storage/schema/delta/33/access_tokens_device_index.sql
+++ b/synapse/storage/data_stores/main/schema/delta/33/access_tokens_device_index.sql
diff --git a/synapse/storage/schema/delta/33/devices.sql b/synapse/storage/data_stores/main/schema/delta/33/devices.sql
index eca7268d82..eca7268d82 100644
--- a/synapse/storage/schema/delta/33/devices.sql
+++ b/synapse/storage/data_stores/main/schema/delta/33/devices.sql
diff --git a/synapse/storage/schema/delta/33/devices_for_e2e_keys.sql b/synapse/storage/data_stores/main/schema/delta/33/devices_for_e2e_keys.sql
index aa4a3b9f2f..aa4a3b9f2f 100644
--- a/synapse/storage/schema/delta/33/devices_for_e2e_keys.sql
+++ b/synapse/storage/data_stores/main/schema/delta/33/devices_for_e2e_keys.sql
diff --git a/synapse/storage/schema/delta/33/devices_for_e2e_keys_clear_unknown_device.sql b/synapse/storage/data_stores/main/schema/delta/33/devices_for_e2e_keys_clear_unknown_device.sql
index 6671573398..6671573398 100644
--- a/synapse/storage/schema/delta/33/devices_for_e2e_keys_clear_unknown_device.sql
+++ b/synapse/storage/data_stores/main/schema/delta/33/devices_for_e2e_keys_clear_unknown_device.sql
diff --git a/synapse/storage/schema/delta/33/event_fields.py b/synapse/storage/data_stores/main/schema/delta/33/event_fields.py
index bff1256a7b..bff1256a7b 100644
--- a/synapse/storage/schema/delta/33/event_fields.py
+++ b/synapse/storage/data_stores/main/schema/delta/33/event_fields.py
diff --git a/synapse/storage/schema/delta/33/remote_media_ts.py b/synapse/storage/data_stores/main/schema/delta/33/remote_media_ts.py
index a26057dfb6..a26057dfb6 100644
--- a/synapse/storage/schema/delta/33/remote_media_ts.py
+++ b/synapse/storage/data_stores/main/schema/delta/33/remote_media_ts.py
diff --git a/synapse/storage/schema/delta/33/user_ips_index.sql b/synapse/storage/data_stores/main/schema/delta/33/user_ips_index.sql
index 473f75a78e..473f75a78e 100644
--- a/synapse/storage/schema/delta/33/user_ips_index.sql
+++ b/synapse/storage/data_stores/main/schema/delta/33/user_ips_index.sql
diff --git a/synapse/storage/schema/delta/34/appservice_stream.sql b/synapse/storage/data_stores/main/schema/delta/34/appservice_stream.sql
index 69e16eda0f..69e16eda0f 100644
--- a/synapse/storage/schema/delta/34/appservice_stream.sql
+++ b/synapse/storage/data_stores/main/schema/delta/34/appservice_stream.sql
diff --git a/synapse/storage/schema/delta/34/cache_stream.py b/synapse/storage/data_stores/main/schema/delta/34/cache_stream.py
index cf09e43e2b..cf09e43e2b 100644
--- a/synapse/storage/schema/delta/34/cache_stream.py
+++ b/synapse/storage/data_stores/main/schema/delta/34/cache_stream.py
diff --git a/synapse/storage/schema/delta/34/device_inbox.sql b/synapse/storage/data_stores/main/schema/delta/34/device_inbox.sql
index e68844c74a..e68844c74a 100644
--- a/synapse/storage/schema/delta/34/device_inbox.sql
+++ b/synapse/storage/data_stores/main/schema/delta/34/device_inbox.sql
diff --git a/synapse/storage/schema/delta/34/push_display_name_rename.sql b/synapse/storage/data_stores/main/schema/delta/34/push_display_name_rename.sql
index 0d9fe1a99a..0d9fe1a99a 100644
--- a/synapse/storage/schema/delta/34/push_display_name_rename.sql
+++ b/synapse/storage/data_stores/main/schema/delta/34/push_display_name_rename.sql
diff --git a/synapse/storage/schema/delta/34/received_txn_purge.py b/synapse/storage/data_stores/main/schema/delta/34/received_txn_purge.py
index 67d505e68b..67d505e68b 100644
--- a/synapse/storage/schema/delta/34/received_txn_purge.py
+++ b/synapse/storage/data_stores/main/schema/delta/34/received_txn_purge.py
diff --git a/synapse/storage/schema/delta/35/contains_url.sql b/synapse/storage/data_stores/main/schema/delta/35/contains_url.sql
index 6cd123027b..6cd123027b 100644
--- a/synapse/storage/schema/delta/35/contains_url.sql
+++ b/synapse/storage/data_stores/main/schema/delta/35/contains_url.sql
diff --git a/synapse/storage/schema/delta/35/device_outbox.sql b/synapse/storage/data_stores/main/schema/delta/35/device_outbox.sql
index 17e6c43105..17e6c43105 100644
--- a/synapse/storage/schema/delta/35/device_outbox.sql
+++ b/synapse/storage/data_stores/main/schema/delta/35/device_outbox.sql
diff --git a/synapse/storage/schema/delta/35/device_stream_id.sql b/synapse/storage/data_stores/main/schema/delta/35/device_stream_id.sql
index 7ab7d942e2..7ab7d942e2 100644
--- a/synapse/storage/schema/delta/35/device_stream_id.sql
+++ b/synapse/storage/data_stores/main/schema/delta/35/device_stream_id.sql
diff --git a/synapse/storage/schema/delta/35/event_push_actions_index.sql b/synapse/storage/data_stores/main/schema/delta/35/event_push_actions_index.sql
index 2e836d8e9c..2e836d8e9c 100644
--- a/synapse/storage/schema/delta/35/event_push_actions_index.sql
+++ b/synapse/storage/data_stores/main/schema/delta/35/event_push_actions_index.sql
diff --git a/synapse/storage/schema/delta/35/public_room_list_change_stream.sql b/synapse/storage/data_stores/main/schema/delta/35/public_room_list_change_stream.sql
index dd2bf2e28a..dd2bf2e28a 100644
--- a/synapse/storage/schema/delta/35/public_room_list_change_stream.sql
+++ b/synapse/storage/data_stores/main/schema/delta/35/public_room_list_change_stream.sql
diff --git a/synapse/storage/schema/delta/35/stream_order_to_extrem.sql b/synapse/storage/data_stores/main/schema/delta/35/stream_order_to_extrem.sql
index 2b945d8a57..2b945d8a57 100644
--- a/synapse/storage/schema/delta/35/stream_order_to_extrem.sql
+++ b/synapse/storage/data_stores/main/schema/delta/35/stream_order_to_extrem.sql
diff --git a/synapse/storage/schema/delta/36/readd_public_rooms.sql b/synapse/storage/data_stores/main/schema/delta/36/readd_public_rooms.sql
index 90d8fd18f9..90d8fd18f9 100644
--- a/synapse/storage/schema/delta/36/readd_public_rooms.sql
+++ b/synapse/storage/data_stores/main/schema/delta/36/readd_public_rooms.sql
diff --git a/synapse/storage/schema/delta/37/remove_auth_idx.py b/synapse/storage/data_stores/main/schema/delta/37/remove_auth_idx.py
index a377884169..a377884169 100644
--- a/synapse/storage/schema/delta/37/remove_auth_idx.py
+++ b/synapse/storage/data_stores/main/schema/delta/37/remove_auth_idx.py
diff --git a/synapse/storage/schema/delta/37/user_threepids.sql b/synapse/storage/data_stores/main/schema/delta/37/user_threepids.sql
index cf7a90dd10..cf7a90dd10 100644
--- a/synapse/storage/schema/delta/37/user_threepids.sql
+++ b/synapse/storage/data_stores/main/schema/delta/37/user_threepids.sql
diff --git a/synapse/storage/schema/delta/38/postgres_fts_gist.sql b/synapse/storage/data_stores/main/schema/delta/38/postgres_fts_gist.sql
index 515e6b8e84..515e6b8e84 100644
--- a/synapse/storage/schema/delta/38/postgres_fts_gist.sql
+++ b/synapse/storage/data_stores/main/schema/delta/38/postgres_fts_gist.sql
diff --git a/synapse/storage/schema/delta/39/appservice_room_list.sql b/synapse/storage/data_stores/main/schema/delta/39/appservice_room_list.sql
index 74bdc49073..74bdc49073 100644
--- a/synapse/storage/schema/delta/39/appservice_room_list.sql
+++ b/synapse/storage/data_stores/main/schema/delta/39/appservice_room_list.sql
diff --git a/synapse/storage/schema/delta/39/device_federation_stream_idx.sql b/synapse/storage/data_stores/main/schema/delta/39/device_federation_stream_idx.sql
index 00be801e90..00be801e90 100644
--- a/synapse/storage/schema/delta/39/device_federation_stream_idx.sql
+++ b/synapse/storage/data_stores/main/schema/delta/39/device_federation_stream_idx.sql
diff --git a/synapse/storage/schema/delta/39/event_push_index.sql b/synapse/storage/data_stores/main/schema/delta/39/event_push_index.sql
index de2ad93e5c..de2ad93e5c 100644
--- a/synapse/storage/schema/delta/39/event_push_index.sql
+++ b/synapse/storage/data_stores/main/schema/delta/39/event_push_index.sql
diff --git a/synapse/storage/schema/delta/39/federation_out_position.sql b/synapse/storage/data_stores/main/schema/delta/39/federation_out_position.sql
index 5af814290b..5af814290b 100644
--- a/synapse/storage/schema/delta/39/federation_out_position.sql
+++ b/synapse/storage/data_stores/main/schema/delta/39/federation_out_position.sql
diff --git a/synapse/storage/schema/delta/39/membership_profile.sql b/synapse/storage/data_stores/main/schema/delta/39/membership_profile.sql
index 1bf911c8ab..1bf911c8ab 100644
--- a/synapse/storage/schema/delta/39/membership_profile.sql
+++ b/synapse/storage/data_stores/main/schema/delta/39/membership_profile.sql
diff --git a/synapse/storage/schema/delta/40/current_state_idx.sql b/synapse/storage/data_stores/main/schema/delta/40/current_state_idx.sql
index 7ffa189f39..7ffa189f39 100644
--- a/synapse/storage/schema/delta/40/current_state_idx.sql
+++ b/synapse/storage/data_stores/main/schema/delta/40/current_state_idx.sql
diff --git a/synapse/storage/schema/delta/40/device_inbox.sql b/synapse/storage/data_stores/main/schema/delta/40/device_inbox.sql
index b9fe1f0480..b9fe1f0480 100644
--- a/synapse/storage/schema/delta/40/device_inbox.sql
+++ b/synapse/storage/data_stores/main/schema/delta/40/device_inbox.sql
diff --git a/synapse/storage/schema/delta/40/device_list_streams.sql b/synapse/storage/data_stores/main/schema/delta/40/device_list_streams.sql
index dd6dcb65f1..dd6dcb65f1 100644
--- a/synapse/storage/schema/delta/40/device_list_streams.sql
+++ b/synapse/storage/data_stores/main/schema/delta/40/device_list_streams.sql
diff --git a/synapse/storage/schema/delta/40/event_push_summary.sql b/synapse/storage/data_stores/main/schema/delta/40/event_push_summary.sql
index 3918f0b794..3918f0b794 100644
--- a/synapse/storage/schema/delta/40/event_push_summary.sql
+++ b/synapse/storage/data_stores/main/schema/delta/40/event_push_summary.sql
diff --git a/synapse/storage/schema/delta/40/pushers.sql b/synapse/storage/data_stores/main/schema/delta/40/pushers.sql
index 054a223f14..054a223f14 100644
--- a/synapse/storage/schema/delta/40/pushers.sql
+++ b/synapse/storage/data_stores/main/schema/delta/40/pushers.sql
diff --git a/synapse/storage/schema/delta/41/device_list_stream_idx.sql b/synapse/storage/data_stores/main/schema/delta/41/device_list_stream_idx.sql
index b7bee8b692..b7bee8b692 100644
--- a/synapse/storage/schema/delta/41/device_list_stream_idx.sql
+++ b/synapse/storage/data_stores/main/schema/delta/41/device_list_stream_idx.sql
diff --git a/synapse/storage/schema/delta/41/device_outbound_index.sql b/synapse/storage/data_stores/main/schema/delta/41/device_outbound_index.sql
index 62f0b9892b..62f0b9892b 100644
--- a/synapse/storage/schema/delta/41/device_outbound_index.sql
+++ b/synapse/storage/data_stores/main/schema/delta/41/device_outbound_index.sql
diff --git a/synapse/storage/schema/delta/41/event_search_event_id_idx.sql b/synapse/storage/data_stores/main/schema/delta/41/event_search_event_id_idx.sql
index 5d9cfecf36..5d9cfecf36 100644
--- a/synapse/storage/schema/delta/41/event_search_event_id_idx.sql
+++ b/synapse/storage/data_stores/main/schema/delta/41/event_search_event_id_idx.sql
diff --git a/synapse/storage/schema/delta/41/ratelimit.sql b/synapse/storage/data_stores/main/schema/delta/41/ratelimit.sql
index a194bf0238..a194bf0238 100644
--- a/synapse/storage/schema/delta/41/ratelimit.sql
+++ b/synapse/storage/data_stores/main/schema/delta/41/ratelimit.sql
diff --git a/synapse/storage/schema/delta/42/current_state_delta.sql b/synapse/storage/data_stores/main/schema/delta/42/current_state_delta.sql
index d28851aff8..d28851aff8 100644
--- a/synapse/storage/schema/delta/42/current_state_delta.sql
+++ b/synapse/storage/data_stores/main/schema/delta/42/current_state_delta.sql
diff --git a/synapse/storage/schema/delta/42/device_list_last_id.sql b/synapse/storage/data_stores/main/schema/delta/42/device_list_last_id.sql
index 9ab8c14fa3..9ab8c14fa3 100644
--- a/synapse/storage/schema/delta/42/device_list_last_id.sql
+++ b/synapse/storage/data_stores/main/schema/delta/42/device_list_last_id.sql
diff --git a/synapse/storage/schema/delta/42/event_auth_state_only.sql b/synapse/storage/data_stores/main/schema/delta/42/event_auth_state_only.sql
index b8821ac759..b8821ac759 100644
--- a/synapse/storage/schema/delta/42/event_auth_state_only.sql
+++ b/synapse/storage/data_stores/main/schema/delta/42/event_auth_state_only.sql
diff --git a/synapse/storage/schema/delta/42/user_dir.py b/synapse/storage/data_stores/main/schema/delta/42/user_dir.py
index 506f326f4d..506f326f4d 100644
--- a/synapse/storage/schema/delta/42/user_dir.py
+++ b/synapse/storage/data_stores/main/schema/delta/42/user_dir.py
diff --git a/synapse/storage/schema/delta/43/blocked_rooms.sql b/synapse/storage/data_stores/main/schema/delta/43/blocked_rooms.sql
index 0e3cd143ff..0e3cd143ff 100644
--- a/synapse/storage/schema/delta/43/blocked_rooms.sql
+++ b/synapse/storage/data_stores/main/schema/delta/43/blocked_rooms.sql
diff --git a/synapse/storage/schema/delta/43/quarantine_media.sql b/synapse/storage/data_stores/main/schema/delta/43/quarantine_media.sql
index 630907ec4f..630907ec4f 100644
--- a/synapse/storage/schema/delta/43/quarantine_media.sql
+++ b/synapse/storage/data_stores/main/schema/delta/43/quarantine_media.sql
diff --git a/synapse/storage/schema/delta/43/url_cache.sql b/synapse/storage/data_stores/main/schema/delta/43/url_cache.sql
index 45ebe020da..45ebe020da 100644
--- a/synapse/storage/schema/delta/43/url_cache.sql
+++ b/synapse/storage/data_stores/main/schema/delta/43/url_cache.sql
diff --git a/synapse/storage/schema/delta/43/user_share.sql b/synapse/storage/data_stores/main/schema/delta/43/user_share.sql
index ee7062abe4..ee7062abe4 100644
--- a/synapse/storage/schema/delta/43/user_share.sql
+++ b/synapse/storage/data_stores/main/schema/delta/43/user_share.sql
diff --git a/synapse/storage/schema/delta/44/expire_url_cache.sql b/synapse/storage/data_stores/main/schema/delta/44/expire_url_cache.sql
index b12f9b2ebf..b12f9b2ebf 100644
--- a/synapse/storage/schema/delta/44/expire_url_cache.sql
+++ b/synapse/storage/data_stores/main/schema/delta/44/expire_url_cache.sql
diff --git a/synapse/storage/schema/delta/45/group_server.sql b/synapse/storage/data_stores/main/schema/delta/45/group_server.sql
index b2333848a0..b2333848a0 100644
--- a/synapse/storage/schema/delta/45/group_server.sql
+++ b/synapse/storage/data_stores/main/schema/delta/45/group_server.sql
diff --git a/synapse/storage/schema/delta/45/profile_cache.sql b/synapse/storage/data_stores/main/schema/delta/45/profile_cache.sql
index e5ddc84df0..e5ddc84df0 100644
--- a/synapse/storage/schema/delta/45/profile_cache.sql
+++ b/synapse/storage/data_stores/main/schema/delta/45/profile_cache.sql
diff --git a/synapse/storage/schema/delta/46/drop_refresh_tokens.sql b/synapse/storage/data_stores/main/schema/delta/46/drop_refresh_tokens.sql
index 68c48a89a9..68c48a89a9 100644
--- a/synapse/storage/schema/delta/46/drop_refresh_tokens.sql
+++ b/synapse/storage/data_stores/main/schema/delta/46/drop_refresh_tokens.sql
diff --git a/synapse/storage/schema/delta/46/drop_unique_deleted_pushers.sql b/synapse/storage/data_stores/main/schema/delta/46/drop_unique_deleted_pushers.sql
index bb307889c1..bb307889c1 100644
--- a/synapse/storage/schema/delta/46/drop_unique_deleted_pushers.sql
+++ b/synapse/storage/data_stores/main/schema/delta/46/drop_unique_deleted_pushers.sql
diff --git a/synapse/storage/schema/delta/46/group_server.sql b/synapse/storage/data_stores/main/schema/delta/46/group_server.sql
index 097679bc9a..097679bc9a 100644
--- a/synapse/storage/schema/delta/46/group_server.sql
+++ b/synapse/storage/data_stores/main/schema/delta/46/group_server.sql
diff --git a/synapse/storage/schema/delta/46/local_media_repository_url_idx.sql b/synapse/storage/data_stores/main/schema/delta/46/local_media_repository_url_idx.sql
index bbfc7f5d1a..bbfc7f5d1a 100644
--- a/synapse/storage/schema/delta/46/local_media_repository_url_idx.sql
+++ b/synapse/storage/data_stores/main/schema/delta/46/local_media_repository_url_idx.sql
diff --git a/synapse/storage/schema/delta/46/user_dir_null_room_ids.sql b/synapse/storage/data_stores/main/schema/delta/46/user_dir_null_room_ids.sql
index cb0d5a2576..cb0d5a2576 100644
--- a/synapse/storage/schema/delta/46/user_dir_null_room_ids.sql
+++ b/synapse/storage/data_stores/main/schema/delta/46/user_dir_null_room_ids.sql
diff --git a/synapse/storage/schema/delta/46/user_dir_typos.sql b/synapse/storage/data_stores/main/schema/delta/46/user_dir_typos.sql
index d9505f8da1..d9505f8da1 100644
--- a/synapse/storage/schema/delta/46/user_dir_typos.sql
+++ b/synapse/storage/data_stores/main/schema/delta/46/user_dir_typos.sql
diff --git a/synapse/storage/schema/delta/47/last_access_media.sql b/synapse/storage/data_stores/main/schema/delta/47/last_access_media.sql
index f505fb22b5..f505fb22b5 100644
--- a/synapse/storage/schema/delta/47/last_access_media.sql
+++ b/synapse/storage/data_stores/main/schema/delta/47/last_access_media.sql
diff --git a/synapse/storage/schema/delta/47/postgres_fts_gin.sql b/synapse/storage/data_stores/main/schema/delta/47/postgres_fts_gin.sql
index 31d7a817eb..31d7a817eb 100644
--- a/synapse/storage/schema/delta/47/postgres_fts_gin.sql
+++ b/synapse/storage/data_stores/main/schema/delta/47/postgres_fts_gin.sql
diff --git a/synapse/storage/schema/delta/47/push_actions_staging.sql b/synapse/storage/data_stores/main/schema/delta/47/push_actions_staging.sql
index edccf4a96f..edccf4a96f 100644
--- a/synapse/storage/schema/delta/47/push_actions_staging.sql
+++ b/synapse/storage/data_stores/main/schema/delta/47/push_actions_staging.sql
diff --git a/synapse/storage/schema/delta/48/add_user_consent.sql b/synapse/storage/data_stores/main/schema/delta/48/add_user_consent.sql
index 5237491506..5237491506 100644
--- a/synapse/storage/schema/delta/48/add_user_consent.sql
+++ b/synapse/storage/data_stores/main/schema/delta/48/add_user_consent.sql
diff --git a/synapse/storage/schema/delta/48/add_user_ips_last_seen_index.sql b/synapse/storage/data_stores/main/schema/delta/48/add_user_ips_last_seen_index.sql
index 9248b0b24a..9248b0b24a 100644
--- a/synapse/storage/schema/delta/48/add_user_ips_last_seen_index.sql
+++ b/synapse/storage/data_stores/main/schema/delta/48/add_user_ips_last_seen_index.sql
diff --git a/synapse/storage/schema/delta/48/deactivated_users.sql b/synapse/storage/data_stores/main/schema/delta/48/deactivated_users.sql
index e9013a6969..e9013a6969 100644
--- a/synapse/storage/schema/delta/48/deactivated_users.sql
+++ b/synapse/storage/data_stores/main/schema/delta/48/deactivated_users.sql
diff --git a/synapse/storage/schema/delta/48/group_unique_indexes.py b/synapse/storage/data_stores/main/schema/delta/48/group_unique_indexes.py
index 49f5f2c003..49f5f2c003 100644
--- a/synapse/storage/schema/delta/48/group_unique_indexes.py
+++ b/synapse/storage/data_stores/main/schema/delta/48/group_unique_indexes.py
diff --git a/synapse/storage/schema/delta/48/groups_joinable.sql b/synapse/storage/data_stores/main/schema/delta/48/groups_joinable.sql
index ce26eaf0c9..ce26eaf0c9 100644
--- a/synapse/storage/schema/delta/48/groups_joinable.sql
+++ b/synapse/storage/data_stores/main/schema/delta/48/groups_joinable.sql
diff --git a/synapse/storage/schema/delta/49/add_user_consent_server_notice_sent.sql b/synapse/storage/data_stores/main/schema/delta/49/add_user_consent_server_notice_sent.sql
index 14dcf18d73..14dcf18d73 100644
--- a/synapse/storage/schema/delta/49/add_user_consent_server_notice_sent.sql
+++ b/synapse/storage/data_stores/main/schema/delta/49/add_user_consent_server_notice_sent.sql
diff --git a/synapse/storage/schema/delta/49/add_user_daily_visits.sql b/synapse/storage/data_stores/main/schema/delta/49/add_user_daily_visits.sql
index 3dd478196f..3dd478196f 100644
--- a/synapse/storage/schema/delta/49/add_user_daily_visits.sql
+++ b/synapse/storage/data_stores/main/schema/delta/49/add_user_daily_visits.sql
diff --git a/synapse/storage/schema/delta/49/add_user_ips_last_seen_only_index.sql b/synapse/storage/data_stores/main/schema/delta/49/add_user_ips_last_seen_only_index.sql
index 3a4ed59b5b..3a4ed59b5b 100644
--- a/synapse/storage/schema/delta/49/add_user_ips_last_seen_only_index.sql
+++ b/synapse/storage/data_stores/main/schema/delta/49/add_user_ips_last_seen_only_index.sql
diff --git a/synapse/storage/schema/delta/50/add_creation_ts_users_index.sql b/synapse/storage/data_stores/main/schema/delta/50/add_creation_ts_users_index.sql
index c93ae47532..c93ae47532 100644
--- a/synapse/storage/schema/delta/50/add_creation_ts_users_index.sql
+++ b/synapse/storage/data_stores/main/schema/delta/50/add_creation_ts_users_index.sql
diff --git a/synapse/storage/schema/delta/50/erasure_store.sql b/synapse/storage/data_stores/main/schema/delta/50/erasure_store.sql
index 5d8641a9ab..5d8641a9ab 100644
--- a/synapse/storage/schema/delta/50/erasure_store.sql
+++ b/synapse/storage/data_stores/main/schema/delta/50/erasure_store.sql
diff --git a/synapse/storage/schema/delta/50/make_event_content_nullable.py b/synapse/storage/data_stores/main/schema/delta/50/make_event_content_nullable.py
index b1684a8441..b1684a8441 100644
--- a/synapse/storage/schema/delta/50/make_event_content_nullable.py
+++ b/synapse/storage/data_stores/main/schema/delta/50/make_event_content_nullable.py
diff --git a/synapse/storage/schema/delta/51/e2e_room_keys.sql b/synapse/storage/data_stores/main/schema/delta/51/e2e_room_keys.sql
index c0e66a697d..c0e66a697d 100644
--- a/synapse/storage/schema/delta/51/e2e_room_keys.sql
+++ b/synapse/storage/data_stores/main/schema/delta/51/e2e_room_keys.sql
diff --git a/synapse/storage/schema/delta/51/monthly_active_users.sql b/synapse/storage/data_stores/main/schema/delta/51/monthly_active_users.sql
index c9d537d5a3..c9d537d5a3 100644
--- a/synapse/storage/schema/delta/51/monthly_active_users.sql
+++ b/synapse/storage/data_stores/main/schema/delta/51/monthly_active_users.sql
diff --git a/synapse/storage/schema/delta/52/add_event_to_state_group_index.sql b/synapse/storage/data_stores/main/schema/delta/52/add_event_to_state_group_index.sql
index 91e03d13e1..91e03d13e1 100644
--- a/synapse/storage/schema/delta/52/add_event_to_state_group_index.sql
+++ b/synapse/storage/data_stores/main/schema/delta/52/add_event_to_state_group_index.sql
diff --git a/synapse/storage/schema/delta/52/device_list_streams_unique_idx.sql b/synapse/storage/data_stores/main/schema/delta/52/device_list_streams_unique_idx.sql
index bfa49e6f92..bfa49e6f92 100644
--- a/synapse/storage/schema/delta/52/device_list_streams_unique_idx.sql
+++ b/synapse/storage/data_stores/main/schema/delta/52/device_list_streams_unique_idx.sql
diff --git a/synapse/storage/schema/delta/52/e2e_room_keys.sql b/synapse/storage/data_stores/main/schema/delta/52/e2e_room_keys.sql
index db687cccae..db687cccae 100644
--- a/synapse/storage/schema/delta/52/e2e_room_keys.sql
+++ b/synapse/storage/data_stores/main/schema/delta/52/e2e_room_keys.sql
diff --git a/synapse/storage/schema/delta/53/add_user_type_to_users.sql b/synapse/storage/data_stores/main/schema/delta/53/add_user_type_to_users.sql
index 88ec2f83e5..88ec2f83e5 100644
--- a/synapse/storage/schema/delta/53/add_user_type_to_users.sql
+++ b/synapse/storage/data_stores/main/schema/delta/53/add_user_type_to_users.sql
diff --git a/synapse/storage/schema/delta/53/drop_sent_transactions.sql b/synapse/storage/data_stores/main/schema/delta/53/drop_sent_transactions.sql
index e372f5a44a..e372f5a44a 100644
--- a/synapse/storage/schema/delta/53/drop_sent_transactions.sql
+++ b/synapse/storage/data_stores/main/schema/delta/53/drop_sent_transactions.sql
diff --git a/synapse/storage/schema/delta/53/event_format_version.sql b/synapse/storage/data_stores/main/schema/delta/53/event_format_version.sql
index 1d977c2834..1d977c2834 100644
--- a/synapse/storage/schema/delta/53/event_format_version.sql
+++ b/synapse/storage/data_stores/main/schema/delta/53/event_format_version.sql
diff --git a/synapse/storage/schema/delta/53/user_dir_populate.sql b/synapse/storage/data_stores/main/schema/delta/53/user_dir_populate.sql
index ffcc896b58..ffcc896b58 100644
--- a/synapse/storage/schema/delta/53/user_dir_populate.sql
+++ b/synapse/storage/data_stores/main/schema/delta/53/user_dir_populate.sql
diff --git a/synapse/storage/schema/delta/53/user_ips_index.sql b/synapse/storage/data_stores/main/schema/delta/53/user_ips_index.sql
index b812c5794f..b812c5794f 100644
--- a/synapse/storage/schema/delta/53/user_ips_index.sql
+++ b/synapse/storage/data_stores/main/schema/delta/53/user_ips_index.sql
diff --git a/synapse/storage/schema/delta/53/user_share.sql b/synapse/storage/data_stores/main/schema/delta/53/user_share.sql
index 5831b1a6f8..5831b1a6f8 100644
--- a/synapse/storage/schema/delta/53/user_share.sql
+++ b/synapse/storage/data_stores/main/schema/delta/53/user_share.sql
diff --git a/synapse/storage/schema/delta/53/user_threepid_id.sql b/synapse/storage/data_stores/main/schema/delta/53/user_threepid_id.sql
index 80c2c573b6..80c2c573b6 100644
--- a/synapse/storage/schema/delta/53/user_threepid_id.sql
+++ b/synapse/storage/data_stores/main/schema/delta/53/user_threepid_id.sql
diff --git a/synapse/storage/schema/delta/53/users_in_public_rooms.sql b/synapse/storage/data_stores/main/schema/delta/53/users_in_public_rooms.sql
index f7827ca6d2..f7827ca6d2 100644
--- a/synapse/storage/schema/delta/53/users_in_public_rooms.sql
+++ b/synapse/storage/data_stores/main/schema/delta/53/users_in_public_rooms.sql
diff --git a/synapse/storage/schema/delta/54/account_validity_with_renewal.sql b/synapse/storage/data_stores/main/schema/delta/54/account_validity_with_renewal.sql
index 0adb2ad55e..0adb2ad55e 100644
--- a/synapse/storage/schema/delta/54/account_validity_with_renewal.sql
+++ b/synapse/storage/data_stores/main/schema/delta/54/account_validity_with_renewal.sql
diff --git a/synapse/storage/schema/delta/54/add_validity_to_server_keys.sql b/synapse/storage/data_stores/main/schema/delta/54/add_validity_to_server_keys.sql
index c01aa9d2d9..c01aa9d2d9 100644
--- a/synapse/storage/schema/delta/54/add_validity_to_server_keys.sql
+++ b/synapse/storage/data_stores/main/schema/delta/54/add_validity_to_server_keys.sql
diff --git a/synapse/storage/schema/delta/54/delete_forward_extremities.sql b/synapse/storage/data_stores/main/schema/delta/54/delete_forward_extremities.sql
index b062ec840c..b062ec840c 100644
--- a/synapse/storage/schema/delta/54/delete_forward_extremities.sql
+++ b/synapse/storage/data_stores/main/schema/delta/54/delete_forward_extremities.sql
diff --git a/synapse/storage/schema/delta/54/drop_legacy_tables.sql b/synapse/storage/data_stores/main/schema/delta/54/drop_legacy_tables.sql
index dbbe682697..dbbe682697 100644
--- a/synapse/storage/schema/delta/54/drop_legacy_tables.sql
+++ b/synapse/storage/data_stores/main/schema/delta/54/drop_legacy_tables.sql
diff --git a/synapse/storage/schema/delta/54/drop_presence_list.sql b/synapse/storage/data_stores/main/schema/delta/54/drop_presence_list.sql
index e6ee70c623..e6ee70c623 100644
--- a/synapse/storage/schema/delta/54/drop_presence_list.sql
+++ b/synapse/storage/data_stores/main/schema/delta/54/drop_presence_list.sql
diff --git a/synapse/storage/schema/delta/54/relations.sql b/synapse/storage/data_stores/main/schema/delta/54/relations.sql
index 134862b870..134862b870 100644
--- a/synapse/storage/schema/delta/54/relations.sql
+++ b/synapse/storage/data_stores/main/schema/delta/54/relations.sql
diff --git a/synapse/storage/schema/delta/54/stats.sql b/synapse/storage/data_stores/main/schema/delta/54/stats.sql
index 652e58308e..652e58308e 100644
--- a/synapse/storage/schema/delta/54/stats.sql
+++ b/synapse/storage/data_stores/main/schema/delta/54/stats.sql
diff --git a/synapse/storage/schema/delta/54/stats2.sql b/synapse/storage/data_stores/main/schema/delta/54/stats2.sql
index 3b2d48447f..3b2d48447f 100644
--- a/synapse/storage/schema/delta/54/stats2.sql
+++ b/synapse/storage/data_stores/main/schema/delta/54/stats2.sql
diff --git a/synapse/storage/schema/delta/55/access_token_expiry.sql b/synapse/storage/data_stores/main/schema/delta/55/access_token_expiry.sql
index 4590604bfd..4590604bfd 100644
--- a/synapse/storage/schema/delta/55/access_token_expiry.sql
+++ b/synapse/storage/data_stores/main/schema/delta/55/access_token_expiry.sql
diff --git a/synapse/storage/schema/delta/55/track_threepid_validations.sql b/synapse/storage/data_stores/main/schema/delta/55/track_threepid_validations.sql
index a8eced2e0a..a8eced2e0a 100644
--- a/synapse/storage/schema/delta/55/track_threepid_validations.sql
+++ b/synapse/storage/data_stores/main/schema/delta/55/track_threepid_validations.sql
diff --git a/synapse/storage/schema/delta/55/users_alter_deactivated.sql b/synapse/storage/data_stores/main/schema/delta/55/users_alter_deactivated.sql
index dabdde489b..dabdde489b 100644
--- a/synapse/storage/schema/delta/55/users_alter_deactivated.sql
+++ b/synapse/storage/data_stores/main/schema/delta/55/users_alter_deactivated.sql
diff --git a/synapse/storage/schema/delta/56/add_spans_to_device_lists.sql b/synapse/storage/data_stores/main/schema/delta/56/add_spans_to_device_lists.sql
index 41807eb1e7..41807eb1e7 100644
--- a/synapse/storage/schema/delta/56/add_spans_to_device_lists.sql
+++ b/synapse/storage/data_stores/main/schema/delta/56/add_spans_to_device_lists.sql
diff --git a/synapse/storage/schema/delta/56/current_state_events_membership.sql b/synapse/storage/data_stores/main/schema/delta/56/current_state_events_membership.sql
index 473018676f..473018676f 100644
--- a/synapse/storage/schema/delta/56/current_state_events_membership.sql
+++ b/synapse/storage/data_stores/main/schema/delta/56/current_state_events_membership.sql
diff --git a/synapse/storage/schema/delta/56/current_state_events_membership_mk2.sql b/synapse/storage/data_stores/main/schema/delta/56/current_state_events_membership_mk2.sql
index 3133d42d4a..3133d42d4a 100644
--- a/synapse/storage/schema/delta/56/current_state_events_membership_mk2.sql
+++ b/synapse/storage/data_stores/main/schema/delta/56/current_state_events_membership_mk2.sql
diff --git a/synapse/storage/data_stores/main/schema/delta/56/delete_keys_from_deleted_backups.sql b/synapse/storage/data_stores/main/schema/delta/56/delete_keys_from_deleted_backups.sql
new file mode 100644
index 0000000000..1d2ddb1b1a
--- /dev/null
+++ b/synapse/storage/data_stores/main/schema/delta/56/delete_keys_from_deleted_backups.sql
@@ -0,0 +1,25 @@
+/* Copyright 2019 The Matrix.org Foundation C.I.C
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+/* delete room keys that belong to a deleted room key version, or to room key
+ * versions that don't exist (anymore)
+ */
+DELETE FROM e2e_room_keys
+WHERE version NOT IN (
+  SELECT version
+  FROM e2e_room_keys_versions
+  WHERE e2e_room_keys.user_id = e2e_room_keys_versions.user_id
+  AND e2e_room_keys_versions.deleted = 0
+);
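+
+-- (note: the subquery is correlated on user_id, so a key row is kept only if the
+-- same user still has a non-deleted backup version with a matching version.)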
diff --git a/synapse/storage/schema/delta/56/destinations_failure_ts.sql b/synapse/storage/data_stores/main/schema/delta/56/destinations_failure_ts.sql
index f00889290b..f00889290b 100644
--- a/synapse/storage/schema/delta/56/destinations_failure_ts.sql
+++ b/synapse/storage/data_stores/main/schema/delta/56/destinations_failure_ts.sql
diff --git a/synapse/storage/data_stores/main/schema/delta/56/destinations_retry_interval_type.sql.postgres b/synapse/storage/data_stores/main/schema/delta/56/destinations_retry_interval_type.sql.postgres
new file mode 100644
index 0000000000..b9bbb18a91
--- /dev/null
+++ b/synapse/storage/data_stores/main/schema/delta/56/destinations_retry_interval_type.sql.postgres
@@ -0,0 +1,18 @@
+/* Copyright 2019 The Matrix.org Foundation C.I.C
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+-- We want to store large retry intervals so we upgrade the column from INT
+-- to BIGINT. We don't need to do this on SQLite.
+ALTER TABLE destinations ALTER retry_interval SET DATA TYPE BIGINT;
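+
+-- (SQLite's INTEGER storage class already holds 64-bit values, which is why no
+-- equivalent .sql.sqlite delta is needed.)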
diff --git a/synapse/storage/data_stores/main/schema/delta/56/device_stream_id_insert.sql b/synapse/storage/data_stores/main/schema/delta/56/device_stream_id_insert.sql
new file mode 100644
index 0000000000..c2f557fde9
--- /dev/null
+++ b/synapse/storage/data_stores/main/schema/delta/56/device_stream_id_insert.sql
@@ -0,0 +1,20 @@
+/* Copyright 2019 The Matrix.org Foundation C.I.C.
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+-- This line already existed in deltas/35/device_stream_id but was not included in the
+-- 54 full schema SQL. Add some SQL here to insert the missing row if it does not exist
+INSERT INTO device_max_stream_id (stream_id) SELECT 0 WHERE NOT EXISTS (
+    SELECT * from device_max_stream_id
+);
\ No newline at end of file
diff --git a/synapse/storage/data_stores/main/schema/delta/56/devices_last_seen.sql b/synapse/storage/data_stores/main/schema/delta/56/devices_last_seen.sql
new file mode 100644
index 0000000000..dfa902d0ba
--- /dev/null
+++ b/synapse/storage/data_stores/main/schema/delta/56/devices_last_seen.sql
@@ -0,0 +1,24 @@
+/* Copyright 2019 Matrix.org Foundation CIC
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+-- Track last seen information for a device in the devices table, rather
+-- than relying on it being in the user_ips table (which we want to be able
+-- to purge old entries from)
+ALTER TABLE devices ADD COLUMN last_seen BIGINT;
+ALTER TABLE devices ADD COLUMN ip TEXT;
+ALTER TABLE devices ADD COLUMN user_agent TEXT;
+
+INSERT INTO background_updates (update_name, progress_json) VALUES
+  ('devices_last_seen', '{}');
diff --git a/synapse/storage/schema/delta/56/drop_unused_event_tables.sql b/synapse/storage/data_stores/main/schema/delta/56/drop_unused_event_tables.sql
index 9f09922c67..9f09922c67 100644
--- a/synapse/storage/schema/delta/56/drop_unused_event_tables.sql
+++ b/synapse/storage/data_stores/main/schema/delta/56/drop_unused_event_tables.sql
diff --git a/synapse/storage/data_stores/main/schema/delta/56/event_expiry.sql b/synapse/storage/data_stores/main/schema/delta/56/event_expiry.sql
new file mode 100644
index 0000000000..81a36a8b1d
--- /dev/null
+++ b/synapse/storage/data_stores/main/schema/delta/56/event_expiry.sql
@@ -0,0 +1,21 @@
+/* Copyright 2019 The Matrix.org Foundation C.I.C.
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+CREATE TABLE IF NOT EXISTS event_expiry (
+    event_id TEXT PRIMARY KEY,
+    expiry_ts BIGINT NOT NULL
+);
+
+CREATE INDEX event_expiry_expiry_ts_idx ON event_expiry(expiry_ts);
diff --git a/synapse/storage/data_stores/main/schema/delta/56/event_labels.sql b/synapse/storage/data_stores/main/schema/delta/56/event_labels.sql
new file mode 100644
index 0000000000..5e29c1da19
--- /dev/null
+++ b/synapse/storage/data_stores/main/schema/delta/56/event_labels.sql
@@ -0,0 +1,30 @@
+/* Copyright 2019 The Matrix.org Foundation C.I.C.
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+-- room_id and topological_ordering are denormalised from the events table in order to
+-- make the index work.
+CREATE TABLE IF NOT EXISTS event_labels (
+    event_id TEXT,
+    label TEXT,
+    room_id TEXT NOT NULL,
+    topological_ordering BIGINT NOT NULL,
+    PRIMARY KEY(event_id, label)
+);
+
+
+-- This index lets a paginated event query that filters on a particular label consult the
+-- event_labels table first, which is much quicker than scanning the events table and then
+-- filtering by label when the label is rarely used relative to the size of the room.
+CREATE INDEX event_labels_room_id_label_idx ON event_labels(room_id, label, topological_ordering);
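+
+-- (with room_id and label as the leading columns and topological_ordering last, the
+-- index can be range-scanned in topological order for a given room and label.)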
diff --git a/synapse/storage/data_stores/main/schema/delta/56/event_labels_background_update.sql b/synapse/storage/data_stores/main/schema/delta/56/event_labels_background_update.sql
new file mode 100644
index 0000000000..5f5e0499ae
--- /dev/null
+++ b/synapse/storage/data_stores/main/schema/delta/56/event_labels_background_update.sql
@@ -0,0 +1,17 @@
+/* Copyright 2019 The Matrix.org Foundation C.I.C.
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+INSERT INTO background_updates (update_name, progress_json) VALUES
+  ('event_store_labels', '{}');
diff --git a/synapse/storage/schema/delta/56/fix_room_keys_index.sql b/synapse/storage/data_stores/main/schema/delta/56/fix_room_keys_index.sql
index 014cb3b538..014cb3b538 100644
--- a/synapse/storage/schema/delta/56/fix_room_keys_index.sql
+++ b/synapse/storage/data_stores/main/schema/delta/56/fix_room_keys_index.sql
diff --git a/synapse/storage/data_stores/main/schema/delta/56/hidden_devices.sql b/synapse/storage/data_stores/main/schema/delta/56/hidden_devices.sql
new file mode 100644
index 0000000000..67f8b20297
--- /dev/null
+++ b/synapse/storage/data_stores/main/schema/delta/56/hidden_devices.sql
@@ -0,0 +1,18 @@
+/* Copyright 2019 New Vector Ltd
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+-- the device list needs to know which entries are "real" devices, and which
+-- ones exist only to avoid collisions
+ALTER TABLE devices ADD COLUMN hidden BOOLEAN DEFAULT FALSE;
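+
+-- (the "hidden" entries are, presumably, the rows created for cross-signing key
+-- IDs, which share this table's (user_id, device_id) uniqueness constraint.)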
diff --git a/synapse/storage/data_stores/main/schema/delta/56/hidden_devices_fix.sql.sqlite b/synapse/storage/data_stores/main/schema/delta/56/hidden_devices_fix.sql.sqlite
new file mode 100644
index 0000000000..e8b1fd35d8
--- /dev/null
+++ b/synapse/storage/data_stores/main/schema/delta/56/hidden_devices_fix.sql.sqlite
@@ -0,0 +1,42 @@
+/* Copyright 2019 The Matrix.org Foundation C.I.C.
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+/* Change the hidden column from a default value of FALSE to a default value of
+ * 0, because sqlite3 prior to 3.23.0 caused the hidden column to contain the
+ * string 'FALSE', which is truthy.
+ *
+ * Since sqlite doesn't allow us to just change the default value, we have to
+ * recreate the table, copy the data, fix the rows that have incorrect data, and
+ * replace the old table with the new table.
+ */
+
+CREATE TABLE IF NOT EXISTS devices2 (
+    user_id TEXT NOT NULL,
+    device_id TEXT NOT NULL,
+    display_name TEXT,
+    last_seen BIGINT,
+    ip TEXT,
+    user_agent TEXT,
+    hidden BOOLEAN DEFAULT 0,
+    CONSTRAINT device_uniqueness UNIQUE (user_id, device_id)
+);
+
+INSERT INTO devices2 SELECT * FROM devices;
+
+UPDATE devices2 SET hidden = 0 WHERE hidden = 'FALSE';
+
+DROP TABLE devices;
+
+ALTER TABLE devices2 RENAME TO devices;
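+
+-- (this create / copy / fix-up / swap sequence is the usual way to change a column
+-- default in SQLite, which has no ALTER COLUMN support.)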
diff --git a/synapse/storage/data_stores/main/schema/delta/56/nuke_empty_communities_from_db.sql b/synapse/storage/data_stores/main/schema/delta/56/nuke_empty_communities_from_db.sql
new file mode 100644
index 0000000000..4f24c1405d
--- /dev/null
+++ b/synapse/storage/data_stores/main/schema/delta/56/nuke_empty_communities_from_db.sql
@@ -0,0 +1,29 @@
+/* Copyright 2019 Werner Sembach
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+-- Groups/communities now get deleted when the last member leaves. This is a one time cleanup to remove old groups/communities that were already empty before that change was made.
+DELETE FROM group_attestations_remote WHERE group_id IN (SELECT group_id FROM groups WHERE NOT EXISTS (SELECT group_id FROM group_users WHERE group_id = groups.group_id));
+DELETE FROM group_attestations_renewals WHERE group_id IN (SELECT group_id FROM groups WHERE NOT EXISTS (SELECT group_id FROM group_users WHERE group_id = groups.group_id));
+DELETE FROM group_invites WHERE group_id IN (SELECT group_id FROM groups WHERE NOT EXISTS (SELECT group_id FROM group_users WHERE group_id = groups.group_id));
+DELETE FROM group_roles WHERE group_id IN (SELECT group_id FROM groups WHERE NOT EXISTS (SELECT group_id FROM group_users WHERE group_id = groups.group_id));
+DELETE FROM group_room_categories WHERE group_id IN (SELECT group_id FROM groups WHERE NOT EXISTS (SELECT group_id FROM group_users WHERE group_id = groups.group_id));
+DELETE FROM group_rooms WHERE group_id IN (SELECT group_id FROM groups WHERE NOT EXISTS (SELECT group_id FROM group_users WHERE group_id = groups.group_id));
+DELETE FROM group_summary_roles WHERE group_id IN (SELECT group_id FROM groups WHERE NOT EXISTS (SELECT group_id FROM group_users WHERE group_id = groups.group_id));
+DELETE FROM group_summary_room_categories WHERE group_id IN (SELECT group_id FROM groups WHERE NOT EXISTS (SELECT group_id FROM group_users WHERE group_id = groups.group_id));
+DELETE FROM group_summary_rooms WHERE group_id IN (SELECT group_id FROM groups WHERE NOT EXISTS (SELECT group_id FROM group_users WHERE group_id = groups.group_id));
+DELETE FROM group_summary_users WHERE group_id IN (SELECT group_id FROM groups WHERE NOT EXISTS (SELECT group_id FROM group_users WHERE group_id = groups.group_id));
+DELETE FROM local_group_membership WHERE group_id IN (SELECT group_id FROM groups WHERE NOT EXISTS (SELECT group_id FROM group_users WHERE group_id = groups.group_id));
+DELETE FROM local_group_updates WHERE group_id IN (SELECT group_id FROM groups WHERE NOT EXISTS (SELECT group_id FROM group_users WHERE group_id = groups.group_id));
+DELETE FROM groups WHERE group_id IN (SELECT group_id FROM groups WHERE NOT EXISTS (SELECT group_id FROM group_users WHERE group_id = groups.group_id));
diff --git a/synapse/storage/data_stores/main/schema/delta/56/public_room_list_idx.sql b/synapse/storage/data_stores/main/schema/delta/56/public_room_list_idx.sql
new file mode 100644
index 0000000000..7be31ffebb
--- /dev/null
+++ b/synapse/storage/data_stores/main/schema/delta/56/public_room_list_idx.sql
@@ -0,0 +1,16 @@
+/* Copyright 2019 The Matrix.org Foundation C.I.C.
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+CREATE INDEX public_room_list_stream_network ON public_room_list_stream (appservice_id, network_id, room_id);
diff --git a/synapse/storage/schema/delta/56/redaction_censor.sql b/synapse/storage/data_stores/main/schema/delta/56/redaction_censor.sql
index fe51b02309..ea95db0ed7 100644
--- a/synapse/storage/schema/delta/56/redaction_censor.sql
+++ b/synapse/storage/data_stores/main/schema/delta/56/redaction_censor.sql
@@ -14,4 +14,3 @@
  */
 
 ALTER TABLE redactions ADD COLUMN have_censored BOOL NOT NULL DEFAULT false;
-CREATE INDEX redactions_have_censored ON redactions(event_id) WHERE not have_censored;
diff --git a/synapse/storage/data_stores/main/schema/delta/56/redaction_censor2.sql b/synapse/storage/data_stores/main/schema/delta/56/redaction_censor2.sql
new file mode 100644
index 0000000000..49ce35d794
--- /dev/null
+++ b/synapse/storage/data_stores/main/schema/delta/56/redaction_censor2.sql
@@ -0,0 +1,22 @@
+/* Copyright 2019 The Matrix.org Foundation C.I.C.
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+ALTER TABLE redactions ADD COLUMN received_ts BIGINT;
+
+INSERT INTO background_updates (update_name, progress_json) VALUES
+  ('redactions_received_ts', '{}');
+
+INSERT INTO background_updates (update_name, progress_json) VALUES
+  ('redactions_have_censored_ts_idx', '{}');
diff --git a/synapse/storage/data_stores/main/schema/delta/56/redaction_censor3_fix_update.sql.postgres b/synapse/storage/data_stores/main/schema/delta/56/redaction_censor3_fix_update.sql.postgres
new file mode 100644
index 0000000000..67471f3ef5
--- /dev/null
+++ b/synapse/storage/data_stores/main/schema/delta/56/redaction_censor3_fix_update.sql.postgres
@@ -0,0 +1,25 @@
+/* Copyright 2019 The Matrix.org Foundation C.I.C.
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+
+-- There was a bug where we may have updated censored redactions as bytes,
+-- which can (somehow) cause JSON to be inserted hex-encoded. These updates go
+-- through and undo any such hex-encoded JSON.
+
+INSERT into background_updates (update_name, progress_json)
+  VALUES ('event_fix_redactions_bytes_create_index', '{}');
+
+INSERT into background_updates (update_name, progress_json, depends_on)
+  VALUES ('event_fix_redactions_bytes', '{}', 'event_fix_redactions_bytes_create_index');
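+
+-- (the depends_on value makes the second update wait until the index-creation
+-- update has completed.)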
diff --git a/synapse/storage/data_stores/main/schema/delta/56/redaction_censor4.sql b/synapse/storage/data_stores/main/schema/delta/56/redaction_censor4.sql
new file mode 100644
index 0000000000..b7550f6f4e
--- /dev/null
+++ b/synapse/storage/data_stores/main/schema/delta/56/redaction_censor4.sql
@@ -0,0 +1,16 @@
+/* Copyright 2019 The Matrix.org Foundation C.I.C.
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+DROP INDEX IF EXISTS redactions_have_censored;
diff --git a/synapse/storage/data_stores/main/schema/delta/56/remove_tombstoned_rooms_from_directory.sql b/synapse/storage/data_stores/main/schema/delta/56/remove_tombstoned_rooms_from_directory.sql
new file mode 100644
index 0000000000..aeb17813d3
--- /dev/null
+++ b/synapse/storage/data_stores/main/schema/delta/56/remove_tombstoned_rooms_from_directory.sql
@@ -0,0 +1,18 @@
+/* Copyright 2020 The Matrix.org Foundation C.I.C.
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+-- Now that #6232 is a thing, we can remove old rooms from the directory.
+INSERT INTO background_updates (update_name, progress_json) VALUES
+  ('remove_tombstoned_rooms_from_directory', '{}');
diff --git a/synapse/storage/data_stores/main/schema/delta/56/room_key_etag.sql b/synapse/storage/data_stores/main/schema/delta/56/room_key_etag.sql
new file mode 100644
index 0000000000..7d70dd071e
--- /dev/null
+++ b/synapse/storage/data_stores/main/schema/delta/56/room_key_etag.sql
@@ -0,0 +1,17 @@
+/* Copyright 2019 Matrix.org Foundation C.I.C.
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+-- store the current etag of the backup version
+ALTER TABLE e2e_room_keys_versions ADD COLUMN etag BIGINT;
diff --git a/synapse/storage/schema/delta/56/room_membership_idx.sql b/synapse/storage/data_stores/main/schema/delta/56/room_membership_idx.sql
index 92ab1f5e65..92ab1f5e65 100644
--- a/synapse/storage/schema/delta/56/room_membership_idx.sql
+++ b/synapse/storage/data_stores/main/schema/delta/56/room_membership_idx.sql
diff --git a/synapse/storage/data_stores/main/schema/delta/56/room_retention.sql b/synapse/storage/data_stores/main/schema/delta/56/room_retention.sql
new file mode 100644
index 0000000000..ee6cdf7a14
--- /dev/null
+++ b/synapse/storage/data_stores/main/schema/delta/56/room_retention.sql
@@ -0,0 +1,33 @@
+/* Copyright 2019 New Vector Ltd
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+-- Tracks the retention policy of a room.
+-- A NULL max_lifetime or min_lifetime means that the matching property is not defined in
+-- the room's retention policy state event.
+-- If a room doesn't have a retention policy state event in its state, both max_lifetime
+-- and min_lifetime are NULL.
+CREATE TABLE IF NOT EXISTS room_retention(
+    room_id TEXT,
+    event_id TEXT,
+    min_lifetime BIGINT,
+    max_lifetime BIGINT,
+
+    PRIMARY KEY(room_id, event_id)
+);
+
+CREATE INDEX room_retention_max_lifetime_idx on room_retention(max_lifetime);
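+
+-- (the max_lifetime index is presumably what lets the purge jobs find rooms with a
+-- bounded lifetime without scanning the whole table.)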
+
+INSERT INTO background_updates (update_name, progress_json) VALUES
+  ('insert_room_retention', '{}');
diff --git a/synapse/storage/data_stores/main/schema/delta/56/signing_keys.sql b/synapse/storage/data_stores/main/schema/delta/56/signing_keys.sql
new file mode 100644
index 0000000000..5c5fffcafb
--- /dev/null
+++ b/synapse/storage/data_stores/main/schema/delta/56/signing_keys.sql
@@ -0,0 +1,56 @@
+/* Copyright 2019 New Vector Ltd
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+-- cross-signing keys
+CREATE TABLE IF NOT EXISTS e2e_cross_signing_keys (
+    user_id TEXT NOT NULL,
+    -- the type of cross-signing key (master, user_signing, or self_signing)
+    keytype TEXT NOT NULL,
+    -- the full key information, as a json-encoded dict
+    keydata TEXT NOT NULL,
+    -- for keeping the keys in order, so that we can fetch the latest one
+    stream_id BIGINT NOT NULL
+);
+
+CREATE UNIQUE INDEX e2e_cross_signing_keys_idx ON e2e_cross_signing_keys(user_id, keytype, stream_id);
+
+-- cross-signing signatures
+CREATE TABLE IF NOT EXISTS e2e_cross_signing_signatures (
+    -- user who did the signing
+    user_id TEXT NOT NULL,
+    -- key used to sign
+    key_id TEXT NOT NULL,
+    -- user who was signed
+    target_user_id TEXT NOT NULL,
+    -- device/key that was signed
+    target_device_id TEXT NOT NULL,
+    -- the actual signature
+    signature TEXT NOT NULL
+);
+
+-- replaced by the index created in signing_keys_nonunique_signatures.sql
+-- CREATE UNIQUE INDEX e2e_cross_signing_signatures_idx ON e2e_cross_signing_signatures(user_id, target_user_id, target_device_id);
+
+-- stream of user signature updates
+CREATE TABLE IF NOT EXISTS user_signature_stream (
+    -- uses the same stream ID as device list stream
+    stream_id BIGINT NOT NULL,
+    -- user who did the signing
+    from_user_id TEXT NOT NULL,
+    -- list of users who were signed, as a JSON array
+    user_ids TEXT NOT NULL
+);
+
+CREATE UNIQUE INDEX user_signature_stream_idx ON user_signature_stream(stream_id);
diff --git a/synapse/storage/data_stores/main/schema/delta/56/signing_keys_nonunique_signatures.sql b/synapse/storage/data_stores/main/schema/delta/56/signing_keys_nonunique_signatures.sql
new file mode 100644
index 0000000000..0aa90ebf0c
--- /dev/null
+++ b/synapse/storage/data_stores/main/schema/delta/56/signing_keys_nonunique_signatures.sql
@@ -0,0 +1,22 @@
+/* Copyright 2019 The Matrix.org Foundation C.I.C.
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+/* The cross-signing signatures index should not be a unique index, because a
+ * user may upload multiple signatures for the same target user. The previous
+ * index was unique, so delete it if it's there and create a new non-unique
+ * index. */
+
+DROP INDEX IF EXISTS e2e_cross_signing_signatures_idx;
+CREATE INDEX IF NOT EXISTS e2e_cross_signing_signatures2_idx ON e2e_cross_signing_signatures(user_id, target_user_id, target_device_id);
diff --git a/synapse/storage/schema/delta/56/stats_separated.sql b/synapse/storage/data_stores/main/schema/delta/56/stats_separated.sql
index 163529c071..bbdde121e8 100644
--- a/synapse/storage/schema/delta/56/stats_separated.sql
+++ b/synapse/storage/data_stores/main/schema/delta/56/stats_separated.sql
@@ -35,9 +35,13 @@ DELETE FROM background_updates WHERE update_name IN (
     'populate_stats_cleanup'
 );
 
+-- this relies on current_state_events.membership having been populated, so add
+-- a dependency on current_state_events_membership.
 INSERT INTO background_updates (update_name, progress_json, depends_on) VALUES
-    ('populate_stats_process_rooms', '{}', '');
+    ('populate_stats_process_rooms', '{}', 'current_state_events_membership');
 
+-- this also relies on current_state_events.membership having been populated, but
+-- we get that as a side-effect of depending on populate_stats_process_rooms.
 INSERT INTO background_updates (update_name, progress_json, depends_on) VALUES
     ('populate_stats_process_users', '{}', 'populate_stats_process_rooms');
 
diff --git a/synapse/storage/data_stores/main/schema/delta/56/unique_user_filter_index.py b/synapse/storage/data_stores/main/schema/delta/56/unique_user_filter_index.py
new file mode 100644
index 0000000000..1de8b54961
--- /dev/null
+++ b/synapse/storage/data_stores/main/schema/delta/56/unique_user_filter_index.py
@@ -0,0 +1,52 @@
+import logging
+
+from synapse.storage.engines import PostgresEngine
+
+logger = logging.getLogger(__name__)
+
+
+"""
+This migration updates the user_filters table as follows:
+
+ - drops any (user_id, filter_id) duplicates
+ - makes the columns NON-NULLable
+ - turns the index into a UNIQUE index
+"""
+
+
+def run_upgrade(cur, database_engine, *args, **kwargs):
+    pass
+
+
+def run_create(cur, database_engine, *args, **kwargs):
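+    # Note: on SQLite, the GROUP BY form below keeps one arbitrary row per
+    # (user_id, filter_id) group (bare columns), which is enough to drop the
+    # duplicates; on Postgres, DISTINCT ON keeps one row per pair in the same way.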
+    if isinstance(database_engine, PostgresEngine):
+        select_clause = """
+            SELECT DISTINCT ON (user_id, filter_id) user_id, filter_id, filter_json
+            FROM user_filters
+        """
+    else:
+        select_clause = """
+            SELECT * FROM user_filters GROUP BY user_id, filter_id
+        """
+    sql = """
+            DROP TABLE IF EXISTS user_filters_migration;
+            DROP INDEX IF EXISTS user_filters_unique;
+            CREATE TABLE user_filters_migration (
+                user_id TEXT NOT NULL,
+                filter_id BIGINT NOT NULL,
+                filter_json BYTEA NOT NULL
+            );
+            INSERT INTO user_filters_migration (user_id, filter_id, filter_json)
+                %s;
+            CREATE UNIQUE INDEX user_filters_unique ON user_filters_migration
+                (user_id, filter_id);
+            DROP TABLE user_filters;
+            ALTER TABLE user_filters_migration RENAME TO user_filters;
+        """ % (
+        select_clause,
+    )
+
+    if isinstance(database_engine, PostgresEngine):
+        cur.execute(sql)
+    else:
+        cur.executescript(sql)
diff --git a/synapse/storage/data_stores/main/schema/delta/56/user_external_ids.sql b/synapse/storage/data_stores/main/schema/delta/56/user_external_ids.sql
new file mode 100644
index 0000000000..91390c4527
--- /dev/null
+++ b/synapse/storage/data_stores/main/schema/delta/56/user_external_ids.sql
@@ -0,0 +1,24 @@
+/* Copyright 2019 The Matrix.org Foundation C.I.C.
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+/*
+ * a table which records mappings from external auth providers to mxids
+ */
+CREATE TABLE IF NOT EXISTS user_external_ids (
+    auth_provider TEXT NOT NULL,
+    external_id TEXT NOT NULL,
+    user_id TEXT NOT NULL,
+    UNIQUE (auth_provider, external_id)
+);
diff --git a/synapse/storage/schema/delta/56/users_in_public_rooms_idx.sql b/synapse/storage/data_stores/main/schema/delta/56/users_in_public_rooms_idx.sql
index 149f8be8b6..149f8be8b6 100644
--- a/synapse/storage/schema/delta/56/users_in_public_rooms_idx.sql
+++ b/synapse/storage/data_stores/main/schema/delta/56/users_in_public_rooms_idx.sql
diff --git a/synapse/storage/data_stores/main/schema/delta/57/delete_old_current_state_events.sql b/synapse/storage/data_stores/main/schema/delta/57/delete_old_current_state_events.sql
new file mode 100644
index 0000000000..aec06c8261
--- /dev/null
+++ b/synapse/storage/data_stores/main/schema/delta/57/delete_old_current_state_events.sql
@@ -0,0 +1,22 @@
+/* Copyright 2020 The Matrix.org Foundation C.I.C
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+-- Add background update to go and delete current state events for rooms the
+-- server is no longer in.
+--
+-- this relies on the 'membership' column of current_state_events, so make sure
+-- that's populated first!
+INSERT into background_updates (update_name, progress_json, depends_on)
+    VALUES ('delete_old_current_state_events', '{}', 'current_state_events_membership');
diff --git a/synapse/storage/data_stores/main/schema/delta/57/device_list_remote_cache_stale.sql b/synapse/storage/data_stores/main/schema/delta/57/device_list_remote_cache_stale.sql
new file mode 100644
index 0000000000..c3b6de2099
--- /dev/null
+++ b/synapse/storage/data_stores/main/schema/delta/57/device_list_remote_cache_stale.sql
@@ -0,0 +1,25 @@
+/* Copyright 2020 The Matrix.org Foundation C.I.C
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+-- Records whether the server thinks that a remote user's cached device list
+-- may be out of date (e.g. if we have received a to-device message from a
+-- device we don't know about).
+CREATE TABLE IF NOT EXISTS device_lists_remote_resync (
+    user_id TEXT NOT NULL,
+    added_ts BIGINT NOT NULL
+);
+
+CREATE UNIQUE INDEX device_lists_remote_resync_idx ON device_lists_remote_resync (user_id);
+CREATE INDEX device_lists_remote_resync_ts_idx ON device_lists_remote_resync (added_ts);
diff --git a/synapse/storage/data_stores/main/schema/delta/57/local_current_membership.py b/synapse/storage/data_stores/main/schema/delta/57/local_current_membership.py
new file mode 100644
index 0000000000..63b5acdcf7
--- /dev/null
+++ b/synapse/storage/data_stores/main/schema/delta/57/local_current_membership.py
@@ -0,0 +1,98 @@
+# -*- coding: utf-8 -*-
+# Copyright 2020 New Vector Ltd
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+
+# We create a new table called `local_current_membership` that stores the latest
+# membership state of local users in rooms, which helps track leaves/bans/etc
+# even if the server has left the room (and so has deleted the room from
+# `current_state_events`). This will also include outstanding invites for local
+# users for rooms the server isn't in.
+#
+# If the server isn't and hasn't been in the room then it will only include
+# outstanding invites, and not e.g. pre-emptive bans of local users.
+#
+# If the server later rejoins a room `local_current_membership` can simply be
+# replaced with the new current state of the room (which results in the
+# equivalent behaviour as if the server had remained in the room).
+
+
+def run_upgrade(cur, database_engine, config, *args, **kwargs):
+    # We need to do the insert in the `run_upgrade` section as we don't have
+    # access to `config` in `run_create`.
+
+    # This upgrade may take a bit of time for large servers (e.g. one minute for
+    # matrix.org) but means we avoid a lot of the bookkeeping required to do it
+    # as a background update.
+
+    # We check if `current_state_events.membership` is up to date by checking
+    # whether the relevant background update has finished. If it has finished
+    # we can avoid doing a join against `room_memberships`, which speeds
+    # things up.
+    cur.execute(
+        """SELECT 1 FROM background_updates
+            WHERE update_name = 'current_state_events_membership'
+        """
+    )
+    current_state_membership_up_to_date = not bool(cur.fetchone())
+
+    # Cheekily drop and recreate indices, as that is faster.
+    cur.execute("DROP INDEX local_current_membership_idx")
+    cur.execute("DROP INDEX local_current_membership_room_idx")
+
+    if current_state_membership_up_to_date:
+        sql = """
+            INSERT INTO local_current_membership (room_id, user_id, event_id, membership)
+                SELECT c.room_id, state_key AS user_id, event_id, c.membership
+                FROM current_state_events AS c
+                WHERE type = 'm.room.member' AND c.membership IS NOT NULL AND state_key LIKE ?
+        """
+    else:
+        # We can't rely on the membership column, so we need to join against
+        # `room_memberships`.
+        sql = """
+            INSERT INTO local_current_membership (room_id, user_id, event_id, membership)
+                SELECT c.room_id, state_key AS user_id, event_id, r.membership
+                FROM current_state_events AS c
+                INNER JOIN room_memberships AS r USING (event_id)
+                WHERE type = 'm.room.member' AND state_key LIKE ?
+        """
+    sql = database_engine.convert_param_style(sql)
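+    # The "%:<server_name>" LIKE pattern below restricts the copy to state_keys
+    # (i.e. user IDs) that are local to this homeserver.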
+    cur.execute(sql, ("%:" + config.server_name,))
+
+    cur.execute(
+        "CREATE UNIQUE INDEX local_current_membership_idx ON local_current_membership(user_id, room_id)"
+    )
+    cur.execute(
+        "CREATE INDEX local_current_membership_room_idx ON local_current_membership(room_id)"
+    )
+
+
+def run_create(cur, database_engine, *args, **kwargs):
+    cur.execute(
+        """
+        CREATE TABLE local_current_membership (
+            room_id TEXT NOT NULL,
+            user_id TEXT NOT NULL,
+            event_id TEXT NOT NULL,
+            membership TEXT NOT NULL
+        )"""
+    )
+
+    cur.execute(
+        "CREATE UNIQUE INDEX local_current_membership_idx ON local_current_membership(user_id, room_id)"
+    )
+    cur.execute(
+        "CREATE INDEX local_current_membership_room_idx ON local_current_membership(room_id)"
+    )
diff --git a/synapse/storage/data_stores/main/schema/delta/57/remove_sent_outbound_pokes.sql b/synapse/storage/data_stores/main/schema/delta/57/remove_sent_outbound_pokes.sql
new file mode 100644
index 0000000000..133d80af35
--- /dev/null
+++ b/synapse/storage/data_stores/main/schema/delta/57/remove_sent_outbound_pokes.sql
@@ -0,0 +1,21 @@
+/* Copyright 2020 The Matrix.org Foundation C.I.C
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+-- we no longer keep sent outbound device pokes in the db; clear them out
+-- so that we don't have to worry about them.
+--
+-- This is a sequence scan, but it doesn't take too long.
+
+DELETE FROM device_lists_outbound_pokes WHERE sent;
diff --git a/synapse/storage/data_stores/main/schema/delta/57/rooms_version_column.sql b/synapse/storage/data_stores/main/schema/delta/57/rooms_version_column.sql
new file mode 100644
index 0000000000..352a66f5b0
--- /dev/null
+++ b/synapse/storage/data_stores/main/schema/delta/57/rooms_version_column.sql
@@ -0,0 +1,24 @@
+/* Copyright 2020 The Matrix.org Foundation C.I.C
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+
+-- We want to start storing the room version independently of
+-- `current_state_events` so that we can delete stale entries from it without
+-- losing the information.
+ALTER TABLE rooms ADD COLUMN room_version TEXT;
+
+
+INSERT into background_updates (update_name, progress_json)
+    VALUES ('add_rooms_room_version_column', '{}');
diff --git a/synapse/storage/data_stores/main/schema/delta/57/rooms_version_column_2.sql.postgres b/synapse/storage/data_stores/main/schema/delta/57/rooms_version_column_2.sql.postgres
new file mode 100644
index 0000000000..c601cff6de
--- /dev/null
+++ b/synapse/storage/data_stores/main/schema/delta/57/rooms_version_column_2.sql.postgres
@@ -0,0 +1,35 @@
+/* Copyright 2020 The Matrix.org Foundation C.I.C.
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+-- when we first added the room_version column, it was populated via a background
+-- update. We now need it to be populated before synapse starts, so here we populate
+-- any remaining rows which still have a NULL room version. For servers which have completed
+-- the background update, this will be pretty quick.
+
+-- the following query will set room_version to NULL if no create event is found for
+-- the room in current_state_events, and will set it to '1' if a create event with no
+-- room_version is found.
+
+UPDATE rooms SET room_version=(
+    SELECT COALESCE(json::json->'content'->>'room_version','1')
+    FROM current_state_events cse INNER JOIN event_json ej USING (event_id)
+    WHERE cse.room_id=rooms.room_id AND cse.type='m.room.create' AND cse.state_key=''
+) WHERE rooms.room_version IS NULL;
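+
+-- (the json::json cast is needed because event_json.json is stored as TEXT.)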
+
+-- we still allow the background update to complete: it has the useful side-effect of
+-- populating `rooms` with any missing rooms (based on the current_state_events table).
+
+-- see also rooms_version_column_2.sql.sqlite which has a copy of the above query, using
+-- sqlite syntax for the json extraction.
diff --git a/synapse/storage/data_stores/main/schema/delta/57/rooms_version_column_2.sql.sqlite b/synapse/storage/data_stores/main/schema/delta/57/rooms_version_column_2.sql.sqlite
new file mode 100644
index 0000000000..335c6f2074
--- /dev/null
+++ b/synapse/storage/data_stores/main/schema/delta/57/rooms_version_column_2.sql.sqlite
@@ -0,0 +1,22 @@
+/* Copyright 2020 The Matrix.org Foundation C.I.C.
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+-- see rooms_version_column_2.sql.postgres for details of what's going on here.
+
+UPDATE rooms SET room_version=(
+    SELECT COALESCE(json_extract(ej.json, '$.content.room_version'), '1')
+    FROM current_state_events cse INNER JOIN event_json ej USING (event_id)
+    WHERE cse.room_id=rooms.room_id AND cse.type='m.room.create' AND cse.state_key=''
+) WHERE rooms.room_version IS NULL;
diff --git a/synapse/storage/data_stores/main/schema/delta/57/rooms_version_column_3.sql.postgres b/synapse/storage/data_stores/main/schema/delta/57/rooms_version_column_3.sql.postgres
new file mode 100644
index 0000000000..92aaadde0d
--- /dev/null
+++ b/synapse/storage/data_stores/main/schema/delta/57/rooms_version_column_3.sql.postgres
@@ -0,0 +1,39 @@
+/* Copyright 2020 The Matrix.org Foundation C.I.C.
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+-- When we first added the room_version column to the rooms table, it was populated from
+-- the current_state_events table. However, there was an issue causing a background
+-- update to clean up the current_state_events table for rooms where the server is no
+-- longer participating, before that column could be populated. Therefore, some rooms had
+-- a NULL room_version.
+
+-- The rooms_version_column_2.sql.* delta files were introduced to make the populating
+-- synchronous instead of running it in a background update, which fixed this issue.
+-- However, all of the instances of Synapse installed or updated in the meantime got
+-- their rooms table corrupted with NULL room_versions.
+
+-- This query fishes out the room versions from the create event using the state_events
+-- table instead of the current_state_events one, as the former still has all of the
+-- create events.
+
+UPDATE rooms SET room_version=(
+    SELECT COALESCE(json::json->'content'->>'room_version','1')
+    FROM state_events se INNER JOIN event_json ej USING (event_id)
+    WHERE se.room_id=rooms.room_id AND se.type='m.room.create' AND se.state_key=''
+    LIMIT 1
+) WHERE rooms.room_version IS NULL;
+
+-- see also rooms_version_column_3.sql.sqlite which has a copy of the above query, using
+-- sqlite syntax for the json extraction.
diff --git a/synapse/storage/data_stores/main/schema/delta/57/rooms_version_column_3.sql.sqlite b/synapse/storage/data_stores/main/schema/delta/57/rooms_version_column_3.sql.sqlite
new file mode 100644
index 0000000000..e19dab97cb
--- /dev/null
+++ b/synapse/storage/data_stores/main/schema/delta/57/rooms_version_column_3.sql.sqlite
@@ -0,0 +1,23 @@
+/* Copyright 2020 The Matrix.org Foundation C.I.C.
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+-- see rooms_version_column_3.sql.postgres for details of what's going on here.
+
+UPDATE rooms SET room_version=(
+    SELECT COALESCE(json_extract(ej.json, '$.content.room_version'), '1')
+    FROM state_events se INNER JOIN event_json ej USING (event_id)
+    WHERE se.room_id=rooms.room_id AND se.type='m.room.create' AND se.state_key=''
+    LIMIT 1
+) WHERE rooms.room_version IS NULL;
diff --git a/synapse/storage/data_stores/main/schema/delta/58/02remove_dup_outbound_pokes.sql b/synapse/storage/data_stores/main/schema/delta/58/02remove_dup_outbound_pokes.sql
new file mode 100644
index 0000000000..fdc39e9ba5
--- /dev/null
+++ b/synapse/storage/data_stores/main/schema/delta/58/02remove_dup_outbound_pokes.sql
@@ -0,0 +1,22 @@
+/* Copyright 2020 The Matrix.org Foundation C.I.C
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+ /* for some reason, we have accumulated duplicate entries in
+  * device_lists_outbound_pokes, which makes prune_outbound_device_list_pokes less
+  * efficient.
+  */
+
+INSERT INTO background_updates (ordering, update_name, progress_json)
+    VALUES (5800, 'remove_dup_outbound_pokes', '{}');
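This delta only queues the background update; the deduplication itself is done by a Python handler registered under the same name. The snippet below sketches that registration pattern, following the `db.updates` API visible elsewhere in this diff; the store class and handler body are illustrative, not the real implementation.

    from synapse.storage._base import SQLBaseStore

    class OutboundPokesStore(SQLBaseStore):  # illustrative class name
        def __init__(self, database, db_conn, hs):
            super().__init__(database, db_conn, hs)
            # The name must match the update_name inserted by the SQL delta.
            self.db.updates.register_background_update_handler(
                "remove_dup_outbound_pokes", self._remove_dup_outbound_pokes
            )

        async def _remove_dup_outbound_pokes(self, progress, batch_size):
            # ...delete duplicate device_lists_outbound_pokes rows in batches...
            # Once everything is processed, mark the update as finished.
            await self.db.updates._end_background_update("remove_dup_outbound_pokes")
            return batch_size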
diff --git a/synapse/storage/data_stores/main/schema/delta/58/03persist_ui_auth.sql b/synapse/storage/data_stores/main/schema/delta/58/03persist_ui_auth.sql
new file mode 100644
index 0000000000..dcb593fc2d
--- /dev/null
+++ b/synapse/storage/data_stores/main/schema/delta/58/03persist_ui_auth.sql
@@ -0,0 +1,36 @@
+/* Copyright 2020 The Matrix.org Foundation C.I.C
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+CREATE TABLE IF NOT EXISTS ui_auth_sessions(
+    session_id TEXT NOT NULL,  -- The session ID passed to the client.
+    creation_time BIGINT NOT NULL,  -- The time this session was created (epoch time in milliseconds).
+    serverdict TEXT NOT NULL,  -- A JSON dictionary of arbitrary data added by Synapse.
+    clientdict TEXT NOT NULL,  -- A JSON dictionary of arbitrary data from the client.
+    uri TEXT NOT NULL,  -- The URI the UI authentication session is using.
+    method TEXT NOT NULL,  -- The HTTP method the UI authentication session is using.
+    -- The clientdict, uri, and method make up a tuple that must be immutable
+    -- throughout the lifetime of the UI Auth session.
+    description TEXT NOT NULL,  -- A human readable description of the operation which caused the UI Auth flow to occur.
+    UNIQUE (session_id)
+);
+
+CREATE TABLE IF NOT EXISTS ui_auth_sessions_credentials(
+    session_id TEXT NOT NULL,  -- The corresponding UI Auth session.
+    stage_type TEXT NOT NULL,  -- The stage type.
+    result TEXT NOT NULL,  -- The result of the stage verification, stored as JSON.
+    UNIQUE (session_id, stage_type),
+    FOREIGN KEY (session_id)
+        REFERENCES ui_auth_sessions (session_id)
+);
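To illustrate how the two tables fit together, here is a standalone sqlite3 sketch (not Synapse code; the session ID, endpoint and stage values are invented for the example): a session row is created when the client first hits a protected endpoint, and each completed stage is then recorded against it.

    import json
    import sqlite3
    import time

    db = sqlite3.connect(":memory:")
    # Apply the two tables defined above (comments trimmed).
    db.executescript("""
    CREATE TABLE ui_auth_sessions(
        session_id TEXT NOT NULL, creation_time BIGINT NOT NULL,
        serverdict TEXT NOT NULL, clientdict TEXT NOT NULL,
        uri TEXT NOT NULL, method TEXT NOT NULL, description TEXT NOT NULL,
        UNIQUE (session_id)
    );
    CREATE TABLE ui_auth_sessions_credentials(
        session_id TEXT NOT NULL, stage_type TEXT NOT NULL, result TEXT NOT NULL,
        UNIQUE (session_id, stage_type),
        FOREIGN KEY (session_id) REFERENCES ui_auth_sessions (session_id)
    );
    """)

    # The client starts a UI Auth flow: persist the immutable request details.
    db.execute(
        "INSERT INTO ui_auth_sessions VALUES (?, ?, ?, ?, ?, ?, ?)",
        ("sess1", int(time.time() * 1000), "{}", json.dumps({"device": "phone"}),
         "/_matrix/client/r0/delete_devices", "POST", "remove device(s)"),
    )

    # A stage (here: password auth) completes: record its result.
    db.execute(
        "INSERT INTO ui_auth_sessions_credentials VALUES (?, ?, ?)",
        ("sess1", "m.login.password", json.dumps(True)),
    )
    db.commit()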
diff --git a/synapse/storage/data_stores/main/schema/delta/58/05cache_instance.sql.postgres b/synapse/storage/data_stores/main/schema/delta/58/05cache_instance.sql.postgres
new file mode 100644
index 0000000000..aa46eb0e10
--- /dev/null
+++ b/synapse/storage/data_stores/main/schema/delta/58/05cache_instance.sql.postgres
@@ -0,0 +1,30 @@
+/* Copyright 2020 The Matrix.org Foundation C.I.C
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+-- We keep the old table here to enable us to roll back. It doesn't matter
+-- that we have dropped all the data here.
+TRUNCATE cache_invalidation_stream;
+
+CREATE TABLE cache_invalidation_stream_by_instance (
+    stream_id       BIGINT NOT NULL,
+    instance_name   TEXT NOT NULL,
+    cache_func      TEXT NOT NULL,
+    keys            TEXT[],
+    invalidation_ts BIGINT
+);
+
+CREATE UNIQUE INDEX cache_invalidation_stream_by_instance_id ON cache_invalidation_stream_by_instance(stream_id);
+
+CREATE SEQUENCE cache_invalidation_stream_seq;
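Writers append to the new table by drawing the next ID from the shared sequence, which keeps invalidations globally ordered even with several instances writing. A rough sketch of such a write (Postgres only; the DSN, instance name, cache name and key are placeholders, and this is not the actual Synapse write path):

    import psycopg2

    conn = psycopg2.connect(dbname="synapse")  # DSN is an assumption
    with conn, conn.cursor() as cur:
        cur.execute(
            """
            INSERT INTO cache_invalidation_stream_by_instance
                (stream_id, instance_name, cache_func, keys, invalidation_ts)
            VALUES (nextval('cache_invalidation_stream_seq'), %s, %s, %s, %s)
            """,
            # psycopg2 adapts the Python list to a Postgres TEXT[] for `keys`.
            ("worker1", "get_room_version_id", ["!room:example.org"], 1585000000000),
        )
    conn.close()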
diff --git a/synapse/storage/data_stores/main/schema/delta/58/06dlols_unique_idx.py b/synapse/storage/data_stores/main/schema/delta/58/06dlols_unique_idx.py
new file mode 100644
index 0000000000..d353f2bcb3
--- /dev/null
+++ b/synapse/storage/data_stores/main/schema/delta/58/06dlols_unique_idx.py
@@ -0,0 +1,80 @@
+# Copyright 2020 The Matrix.org Foundation C.I.C.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+"""
+This migration rebuilds the device_lists_outbound_last_success table without duplicate
+entries, and with a UNIQUE index.
+"""
+
+import logging
+from io import StringIO
+
+from synapse.storage.engines import BaseDatabaseEngine, PostgresEngine
+from synapse.storage.prepare_database import execute_statements_from_stream
+from synapse.storage.types import Cursor
+
+logger = logging.getLogger(__name__)
+
+
+def run_upgrade(*args, **kwargs):
+    pass
+
+
+def run_create(cur: Cursor, database_engine: BaseDatabaseEngine, *args, **kwargs):
+    # some instances might already have this index, in which case we can skip this
+    if isinstance(database_engine, PostgresEngine):
+        cur.execute(
+            """
+            SELECT 1 FROM pg_class WHERE relkind = 'i'
+            AND relname = 'device_lists_outbound_last_success_unique_idx'
+            """
+        )
+
+        if cur.rowcount:
+            logger.info(
+                "Unique index exists on device_lists_outbound_last_success: "
+                "skipping rebuild"
+            )
+            return
+
+    logger.info("Rebuilding device_lists_outbound_last_success with unique index")
+    execute_statements_from_stream(cur, StringIO(_rebuild_commands))
+
+
+# there might be duplicates, so the easiest way to achieve this is to create a new
+# table with the right data, and rename it into place
+
+_rebuild_commands = """
+DROP TABLE IF EXISTS device_lists_outbound_last_success_new;
+
+CREATE TABLE device_lists_outbound_last_success_new (
+    destination TEXT NOT NULL,
+    user_id TEXT NOT NULL,
+    stream_id BIGINT NOT NULL
+);
+
+-- this took about 30 seconds on matrix.org's 16 million rows.
+INSERT INTO device_lists_outbound_last_success_new
+    SELECT destination, user_id, MAX(stream_id) FROM device_lists_outbound_last_success
+    GROUP BY destination, user_id;
+
+-- and this took another 30 seconds.
+CREATE UNIQUE INDEX device_lists_outbound_last_success_unique_idx
+    ON device_lists_outbound_last_success_new (destination, user_id);
+
+DROP TABLE device_lists_outbound_last_success;
+
+ALTER TABLE device_lists_outbound_last_success_new
+    RENAME TO device_lists_outbound_last_success;
+"""
diff --git a/synapse/storage/schema/full_schemas/16/application_services.sql b/synapse/storage/data_stores/main/schema/full_schemas/16/application_services.sql
index 883fcd10b2..883fcd10b2 100644
--- a/synapse/storage/schema/full_schemas/16/application_services.sql
+++ b/synapse/storage/data_stores/main/schema/full_schemas/16/application_services.sql
diff --git a/synapse/storage/schema/full_schemas/16/event_edges.sql b/synapse/storage/data_stores/main/schema/full_schemas/16/event_edges.sql
index 10ce2aa7a0..10ce2aa7a0 100644
--- a/synapse/storage/schema/full_schemas/16/event_edges.sql
+++ b/synapse/storage/data_stores/main/schema/full_schemas/16/event_edges.sql
diff --git a/synapse/storage/schema/full_schemas/16/event_signatures.sql b/synapse/storage/data_stores/main/schema/full_schemas/16/event_signatures.sql
index 95826da431..95826da431 100644
--- a/synapse/storage/schema/full_schemas/16/event_signatures.sql
+++ b/synapse/storage/data_stores/main/schema/full_schemas/16/event_signatures.sql
diff --git a/synapse/storage/schema/full_schemas/16/im.sql b/synapse/storage/data_stores/main/schema/full_schemas/16/im.sql
index a1a2aa8e5b..a1a2aa8e5b 100644
--- a/synapse/storage/schema/full_schemas/16/im.sql
+++ b/synapse/storage/data_stores/main/schema/full_schemas/16/im.sql
diff --git a/synapse/storage/schema/full_schemas/16/keys.sql b/synapse/storage/data_stores/main/schema/full_schemas/16/keys.sql
index 11cdffdbb3..11cdffdbb3 100644
--- a/synapse/storage/schema/full_schemas/16/keys.sql
+++ b/synapse/storage/data_stores/main/schema/full_schemas/16/keys.sql
diff --git a/synapse/storage/schema/full_schemas/16/media_repository.sql b/synapse/storage/data_stores/main/schema/full_schemas/16/media_repository.sql
index 8f3759bb2a..8f3759bb2a 100644
--- a/synapse/storage/schema/full_schemas/16/media_repository.sql
+++ b/synapse/storage/data_stores/main/schema/full_schemas/16/media_repository.sql
diff --git a/synapse/storage/schema/full_schemas/16/presence.sql b/synapse/storage/data_stores/main/schema/full_schemas/16/presence.sql
index 01d2d8f833..01d2d8f833 100644
--- a/synapse/storage/schema/full_schemas/16/presence.sql
+++ b/synapse/storage/data_stores/main/schema/full_schemas/16/presence.sql
diff --git a/synapse/storage/schema/full_schemas/16/profiles.sql b/synapse/storage/data_stores/main/schema/full_schemas/16/profiles.sql
index c04f4747d9..c04f4747d9 100644
--- a/synapse/storage/schema/full_schemas/16/profiles.sql
+++ b/synapse/storage/data_stores/main/schema/full_schemas/16/profiles.sql
diff --git a/synapse/storage/schema/full_schemas/16/push.sql b/synapse/storage/data_stores/main/schema/full_schemas/16/push.sql
index e44465cf45..e44465cf45 100644
--- a/synapse/storage/schema/full_schemas/16/push.sql
+++ b/synapse/storage/data_stores/main/schema/full_schemas/16/push.sql
diff --git a/synapse/storage/schema/full_schemas/16/redactions.sql b/synapse/storage/data_stores/main/schema/full_schemas/16/redactions.sql
index 318f0d9aa5..318f0d9aa5 100644
--- a/synapse/storage/schema/full_schemas/16/redactions.sql
+++ b/synapse/storage/data_stores/main/schema/full_schemas/16/redactions.sql
diff --git a/synapse/storage/schema/full_schemas/16/room_aliases.sql b/synapse/storage/data_stores/main/schema/full_schemas/16/room_aliases.sql
index d47da3b12f..d47da3b12f 100644
--- a/synapse/storage/schema/full_schemas/16/room_aliases.sql
+++ b/synapse/storage/data_stores/main/schema/full_schemas/16/room_aliases.sql
diff --git a/synapse/storage/schema/full_schemas/16/state.sql b/synapse/storage/data_stores/main/schema/full_schemas/16/state.sql
index 96391a8f0e..96391a8f0e 100644
--- a/synapse/storage/schema/full_schemas/16/state.sql
+++ b/synapse/storage/data_stores/main/schema/full_schemas/16/state.sql
diff --git a/synapse/storage/schema/full_schemas/16/transactions.sql b/synapse/storage/data_stores/main/schema/full_schemas/16/transactions.sql
index 17e67bedac..17e67bedac 100644
--- a/synapse/storage/schema/full_schemas/16/transactions.sql
+++ b/synapse/storage/data_stores/main/schema/full_schemas/16/transactions.sql
diff --git a/synapse/storage/schema/full_schemas/16/users.sql b/synapse/storage/data_stores/main/schema/full_schemas/16/users.sql
index f013aa8b18..f013aa8b18 100644
--- a/synapse/storage/schema/full_schemas/16/users.sql
+++ b/synapse/storage/data_stores/main/schema/full_schemas/16/users.sql
diff --git a/synapse/storage/schema/full_schemas/54/full.sql.postgres b/synapse/storage/data_stores/main/schema/full_schemas/54/full.sql.postgres
index 098434356f..889a9a0ce4 100644
--- a/synapse/storage/schema/full_schemas/54/full.sql.postgres
+++ b/synapse/storage/data_stores/main/schema/full_schemas/54/full.sql.postgres
@@ -70,15 +70,6 @@ CREATE TABLE appservice_stream_position (
 );
 
 
-
-CREATE TABLE background_updates (
-    update_name text NOT NULL,
-    progress_json text NOT NULL,
-    depends_on text
-);
-
-
-
 CREATE TABLE blocked_rooms (
     room_id text NOT NULL,
     user_id text NOT NULL
@@ -984,40 +975,6 @@ CREATE TABLE state_events (
 
 
 
-CREATE TABLE state_group_edges (
-    state_group bigint NOT NULL,
-    prev_state_group bigint NOT NULL
-);
-
-
-
-CREATE SEQUENCE state_group_id_seq
-    START WITH 1
-    INCREMENT BY 1
-    NO MINVALUE
-    NO MAXVALUE
-    CACHE 1;
-
-
-
-CREATE TABLE state_groups (
-    id bigint NOT NULL,
-    room_id text NOT NULL,
-    event_id text NOT NULL
-);
-
-
-
-CREATE TABLE state_groups_state (
-    state_group bigint NOT NULL,
-    room_id text NOT NULL,
-    type text NOT NULL,
-    state_key text NOT NULL,
-    event_id text NOT NULL
-);
-
-
-
 CREATE TABLE stats_stream_pos (
     lock character(1) DEFAULT 'X'::bpchar NOT NULL,
     stream_id bigint,
@@ -1202,11 +1159,6 @@ ALTER TABLE ONLY appservice_stream_position
 
 
 
-ALTER TABLE ONLY background_updates
-    ADD CONSTRAINT background_updates_uniqueness UNIQUE (update_name);
-
-
-
 ALTER TABLE ONLY current_state_events
     ADD CONSTRAINT current_state_events_event_id_key UNIQUE (event_id);
 
@@ -1496,12 +1448,6 @@ ALTER TABLE ONLY state_events
     ADD CONSTRAINT state_events_event_id_key UNIQUE (event_id);
 
 
-
-ALTER TABLE ONLY state_groups
-    ADD CONSTRAINT state_groups_pkey PRIMARY KEY (id);
-
-
-
 ALTER TABLE ONLY stats_stream_pos
     ADD CONSTRAINT stats_stream_pos_lock_key UNIQUE (lock);
 
@@ -1942,18 +1888,6 @@ CREATE UNIQUE INDEX room_stats_room_ts ON room_stats USING btree (room_id, ts);
 
 
 
-CREATE INDEX state_group_edges_idx ON state_group_edges USING btree (state_group);
-
-
-
-CREATE INDEX state_group_edges_prev_idx ON state_group_edges USING btree (prev_state_group);
-
-
-
-CREATE INDEX state_groups_state_type_idx ON state_groups_state USING btree (state_group, type, state_key);
-
-
-
 CREATE INDEX stream_ordering_to_exterm_idx ON stream_ordering_to_exterm USING btree (stream_ordering);
 
 
@@ -2047,6 +1981,3 @@ CREATE INDEX users_who_share_private_rooms_r_idx ON users_who_share_private_room
 
 
 CREATE UNIQUE INDEX users_who_share_private_rooms_u_idx ON users_who_share_private_rooms USING btree (user_id, other_user_id, room_id);
-
-
-
diff --git a/synapse/storage/schema/full_schemas/54/full.sql.sqlite b/synapse/storage/data_stores/main/schema/full_schemas/54/full.sql.sqlite
index be9295e4c9..a0411ede7e 100644
--- a/synapse/storage/schema/full_schemas/54/full.sql.sqlite
+++ b/synapse/storage/data_stores/main/schema/full_schemas/54/full.sql.sqlite
@@ -42,8 +42,6 @@ CREATE INDEX ev_edges_id ON event_edges(event_id);
 CREATE INDEX ev_edges_prev_id ON event_edges(prev_event_id);
 CREATE TABLE room_depth( room_id TEXT NOT NULL, min_depth INTEGER NOT NULL, UNIQUE (room_id) );
 CREATE INDEX room_depth_room ON room_depth(room_id);
-CREATE TABLE state_groups( id BIGINT PRIMARY KEY, room_id TEXT NOT NULL, event_id TEXT NOT NULL );
-CREATE TABLE state_groups_state( state_group BIGINT NOT NULL, room_id TEXT NOT NULL, type TEXT NOT NULL, state_key TEXT NOT NULL, event_id TEXT NOT NULL );
 CREATE TABLE event_to_state_groups( event_id TEXT NOT NULL, state_group BIGINT NOT NULL, UNIQUE (event_id) );
 CREATE TABLE local_media_repository ( media_id TEXT, media_type TEXT, media_length INTEGER, created_ts BIGINT, upload_name TEXT, user_id TEXT, quarantined_by TEXT, url_cache TEXT, last_access_ts BIGINT, UNIQUE (media_id) );
 CREATE TABLE local_media_repository_thumbnails ( media_id TEXT, thumbnail_width INTEGER, thumbnail_height INTEGER, thumbnail_type TEXT, thumbnail_method TEXT, thumbnail_length INTEGER, UNIQUE ( media_id, thumbnail_width, thumbnail_height, thumbnail_type ) );
@@ -67,7 +65,6 @@ CREATE INDEX receipts_linearized_id ON receipts_linearized( stream_id );
 CREATE INDEX receipts_linearized_room_stream ON receipts_linearized( room_id, stream_id );
 CREATE TABLE IF NOT EXISTS "user_threepids" ( user_id TEXT NOT NULL, medium TEXT NOT NULL, address TEXT NOT NULL, validated_at BIGINT NOT NULL, added_at BIGINT NOT NULL, CONSTRAINT medium_address UNIQUE (medium, address) );
 CREATE INDEX user_threepids_user_id ON user_threepids(user_id);
-CREATE TABLE background_updates( update_name TEXT NOT NULL, progress_json TEXT NOT NULL, depends_on TEXT, CONSTRAINT background_updates_uniqueness UNIQUE (update_name) );
 CREATE VIRTUAL TABLE event_search USING fts4 ( event_id, room_id, sender, key, value )
 /* event_search(event_id,room_id,sender,"key",value) */;
 CREATE TABLE IF NOT EXISTS 'event_search_content'(docid INTEGER PRIMARY KEY, 'c0event_id', 'c1room_id', 'c2sender', 'c3key', 'c4value');
@@ -121,9 +118,6 @@ CREATE TABLE device_max_stream_id ( stream_id BIGINT NOT NULL );
 CREATE TABLE public_room_list_stream ( stream_id BIGINT NOT NULL, room_id TEXT NOT NULL, visibility BOOLEAN NOT NULL , appservice_id TEXT, network_id TEXT);
 CREATE INDEX public_room_list_stream_idx on public_room_list_stream( stream_id );
 CREATE INDEX public_room_list_stream_rm_idx on public_room_list_stream( room_id, stream_id );
-CREATE TABLE state_group_edges( state_group BIGINT NOT NULL, prev_state_group BIGINT NOT NULL );
-CREATE INDEX state_group_edges_idx ON state_group_edges(state_group);
-CREATE INDEX state_group_edges_prev_idx ON state_group_edges(prev_state_group);
 CREATE TABLE stream_ordering_to_exterm ( stream_ordering BIGINT NOT NULL, room_id TEXT NOT NULL, event_id TEXT NOT NULL );
 CREATE INDEX stream_ordering_to_exterm_idx on stream_ordering_to_exterm( stream_ordering );
 CREATE INDEX stream_ordering_to_exterm_rm_idx on stream_ordering_to_exterm( room_id, stream_ordering );
@@ -255,6 +249,5 @@ CREATE INDEX user_ips_last_seen_only ON user_ips (last_seen);
 CREATE INDEX users_creation_ts ON users (creation_ts);
 CREATE INDEX event_to_state_groups_sg_index ON event_to_state_groups (state_group);
 CREATE UNIQUE INDEX device_lists_remote_cache_unique_id ON device_lists_remote_cache (user_id, device_id);
-CREATE INDEX state_groups_state_type_idx ON state_groups_state(state_group, type, state_key);
 CREATE UNIQUE INDEX device_lists_remote_extremeties_unique_idx ON device_lists_remote_extremeties (user_id);
 CREATE UNIQUE INDEX user_ips_user_token_ip_unique_index ON user_ips (user_id, access_token, ip);
diff --git a/synapse/storage/schema/full_schemas/54/stream_positions.sql b/synapse/storage/data_stores/main/schema/full_schemas/54/stream_positions.sql
index c265fd20e2..91d21b2921 100644
--- a/synapse/storage/schema/full_schemas/54/stream_positions.sql
+++ b/synapse/storage/data_stores/main/schema/full_schemas/54/stream_positions.sql
@@ -5,3 +5,4 @@ INSERT INTO federation_stream_position (type, stream_id) SELECT 'events', coales
 INSERT INTO user_directory_stream_pos (stream_id) VALUES (0);
 INSERT INTO stats_stream_pos (stream_id) VALUES (0);
 INSERT INTO event_push_summary_stream_ordering (stream_ordering) VALUES (0);
+-- device_max_stream_id is handled separately in 56/device_stream_id_insert.sql
\ No newline at end of file
diff --git a/synapse/storage/data_stores/main/schema/full_schemas/README.md b/synapse/storage/data_stores/main/schema/full_schemas/README.md
new file mode 100644
index 0000000000..c00f287190
--- /dev/null
+++ b/synapse/storage/data_stores/main/schema/full_schemas/README.md
@@ -0,0 +1,21 @@
+# Synapse Database Schemas
+
+These schemas are used as a basis to create brand new Synapse databases, on both
+SQLite3 and Postgres.
+
+## Building full schema dumps
+
+If you want to recreate these schemas, they need to be made from a database that
+has had all background updates run.
+
+To do so, use `scripts-dev/make_full_schema.sh`. This will produce new
+`full.sql.postgres` and `full.sql.sqlite` files.
+
+Ensure Postgres is installed, and that your user is able to run commands
+such as `createdb`, then call:
+
+    ./scripts-dev/make_full_schema.sh -p postgres_username -o output_dir/
+
+There are currently two folders with full-schema snapshots. `16` is a snapshot
+from 2015, for historical reference. The other contains the most recent full
+schema snapshot.
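Before taking the dump it is worth double-checking that the source database really has no background updates left to run; otherwise the snapshot will bake in a partially migrated schema. A minimal sketch of that check, assuming a SQLite homeserver database at `homeserver.db` (for Postgres, the same query can be run through psql):

    import sqlite3

    db = sqlite3.connect("homeserver.db")
    pending = [row[0] for row in db.execute("SELECT update_name FROM background_updates")]
    if pending:
        raise SystemExit("background updates still pending: %s" % (pending,))
    print("no pending background updates; safe to dump the schema")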
diff --git a/synapse/storage/search.py b/synapse/storage/data_stores/main/search.py
index df87ab6a6d..13f49d8060 100644
--- a/synapse/storage/search.py
+++ b/synapse/storage/data_stores/main/search.py
@@ -24,10 +24,11 @@ from canonicaljson import json
 from twisted.internet import defer
 
 from synapse.api.errors import SynapseError
+from synapse.storage._base import SQLBaseStore, make_in_list_sql_clause
+from synapse.storage.data_stores.main.events_worker import EventRedactBehaviour
+from synapse.storage.database import Database
 from synapse.storage.engines import PostgresEngine, Sqlite3Engine
 
-from .background_updates import BackgroundUpdateStore
-
 logger = logging.getLogger(__name__)
 
 SearchEntry = namedtuple(
@@ -36,23 +37,71 @@ SearchEntry = namedtuple(
 )
 
 
-class SearchStore(BackgroundUpdateStore):
+class SearchWorkerStore(SQLBaseStore):
+    def store_search_entries_txn(self, txn, entries):
+        """Add entries to the search table
+
+        Args:
+            txn (cursor):
+            entries (iterable[SearchEntry]):
+                entries to be added to the table
+        """
+        if not self.hs.config.enable_search:
+            return
+        if isinstance(self.database_engine, PostgresEngine):
+            sql = (
+                "INSERT INTO event_search"
+                " (event_id, room_id, key, vector, stream_ordering, origin_server_ts)"
+                " VALUES (?,?,?,to_tsvector('english', ?),?,?)"
+            )
+
+            args = (
+                (
+                    entry.event_id,
+                    entry.room_id,
+                    entry.key,
+                    entry.value,
+                    entry.stream_ordering,
+                    entry.origin_server_ts,
+                )
+                for entry in entries
+            )
+
+            txn.executemany(sql, args)
+
+        elif isinstance(self.database_engine, Sqlite3Engine):
+            sql = (
+                "INSERT INTO event_search (event_id, room_id, key, value)"
+                " VALUES (?,?,?,?)"
+            )
+            args = (
+                (entry.event_id, entry.room_id, entry.key, entry.value)
+                for entry in entries
+            )
+
+            txn.executemany(sql, args)
+        else:
+            # This should be unreachable.
+            raise Exception("Unrecognized database engine")
+
+
+class SearchBackgroundUpdateStore(SearchWorkerStore):
 
     EVENT_SEARCH_UPDATE_NAME = "event_search"
     EVENT_SEARCH_ORDER_UPDATE_NAME = "event_search_order"
     EVENT_SEARCH_USE_GIST_POSTGRES_NAME = "event_search_postgres_gist"
     EVENT_SEARCH_USE_GIN_POSTGRES_NAME = "event_search_postgres_gin"
 
-    def __init__(self, db_conn, hs):
-        super(SearchStore, self).__init__(db_conn, hs)
+    def __init__(self, database: Database, db_conn, hs):
+        super(SearchBackgroundUpdateStore, self).__init__(database, db_conn, hs)
 
         if not hs.config.enable_search:
             return
 
-        self.register_background_update_handler(
+        self.db.updates.register_background_update_handler(
             self.EVENT_SEARCH_UPDATE_NAME, self._background_reindex_search
         )
-        self.register_background_update_handler(
+        self.db.updates.register_background_update_handler(
             self.EVENT_SEARCH_ORDER_UPDATE_NAME, self._background_reindex_search_order
         )
 
@@ -61,9 +110,11 @@ class SearchStore(BackgroundUpdateStore):
         # a GIN index. However, it's possible that some people might still have
         # the background update queued, so we register a handler to clear the
         # background update.
-        self.register_noop_background_update(self.EVENT_SEARCH_USE_GIST_POSTGRES_NAME)
+        self.db.updates.register_noop_background_update(
+            self.EVENT_SEARCH_USE_GIST_POSTGRES_NAME
+        )
 
-        self.register_background_update_handler(
+        self.db.updates.register_background_update_handler(
             self.EVENT_SEARCH_USE_GIN_POSTGRES_NAME, self._background_reindex_gin_search
         )
 
@@ -93,7 +144,7 @@ class SearchStore(BackgroundUpdateStore):
             # store_search_entries_txn with a generator function, but that
             # would mean having two cursors open on the database at once.
             # Instead we just build a list of results.
-            rows = self.cursor_to_dict(txn)
+            rows = self.db.cursor_to_dict(txn)
             if not rows:
                 return 0
 
@@ -153,18 +204,18 @@ class SearchStore(BackgroundUpdateStore):
                 "rows_inserted": rows_inserted + len(event_search_rows),
             }
 
-            self._background_update_progress_txn(
+            self.db.updates._background_update_progress_txn(
                 txn, self.EVENT_SEARCH_UPDATE_NAME, progress
             )
 
             return len(event_search_rows)
 
-        result = yield self.runInteraction(
+        result = yield self.db.runInteraction(
             self.EVENT_SEARCH_UPDATE_NAME, reindex_search_txn
         )
 
         if not result:
-            yield self._end_background_update(self.EVENT_SEARCH_UPDATE_NAME)
+            yield self.db.updates._end_background_update(self.EVENT_SEARCH_UPDATE_NAME)
 
         return result
 
@@ -196,7 +247,7 @@ class SearchStore(BackgroundUpdateStore):
                         " ON event_search USING GIN (vector)"
                     )
                 except psycopg2.ProgrammingError as e:
-                    logger.warn(
+                    logger.warning(
                         "Ignoring error %r when trying to switch from GIST to GIN", e
                     )
 
@@ -206,9 +257,11 @@ class SearchStore(BackgroundUpdateStore):
                 conn.set_session(autocommit=False)
 
         if isinstance(self.database_engine, PostgresEngine):
-            yield self.runWithConnection(create_index)
+            yield self.db.runWithConnection(create_index)
 
-        yield self._end_background_update(self.EVENT_SEARCH_USE_GIN_POSTGRES_NAME)
+        yield self.db.updates._end_background_update(
+            self.EVENT_SEARCH_USE_GIN_POSTGRES_NAME
+        )
         return 1
 
     @defer.inlineCallbacks
@@ -237,14 +290,14 @@ class SearchStore(BackgroundUpdateStore):
                 )
                 conn.set_session(autocommit=False)
 
-            yield self.runWithConnection(create_index)
+            yield self.db.runWithConnection(create_index)
 
             pg = dict(progress)
             pg["have_added_indexes"] = True
 
-            yield self.runInteraction(
+            yield self.db.runInteraction(
                 self.EVENT_SEARCH_ORDER_UPDATE_NAME,
-                self._background_update_progress_txn,
+                self.db.updates._background_update_progress_txn,
                 self.EVENT_SEARCH_ORDER_UPDATE_NAME,
                 pg,
             )
@@ -274,89 +327,27 @@ class SearchStore(BackgroundUpdateStore):
                 "have_added_indexes": True,
             }
 
-            self._background_update_progress_txn(
+            self.db.updates._background_update_progress_txn(
                 txn, self.EVENT_SEARCH_ORDER_UPDATE_NAME, progress
             )
 
             return len(rows), True
 
-        num_rows, finished = yield self.runInteraction(
+        num_rows, finished = yield self.db.runInteraction(
             self.EVENT_SEARCH_ORDER_UPDATE_NAME, reindex_search_txn
         )
 
         if not finished:
-            yield self._end_background_update(self.EVENT_SEARCH_ORDER_UPDATE_NAME)
-
-        return num_rows
-
-    def store_event_search_txn(self, txn, event, key, value):
-        """Add event to the search table
-
-        Args:
-            txn (cursor):
-            event (EventBase):
-            key (str):
-            value (str):
-        """
-        self.store_search_entries_txn(
-            txn,
-            (
-                SearchEntry(
-                    key=key,
-                    value=value,
-                    event_id=event.event_id,
-                    room_id=event.room_id,
-                    stream_ordering=event.internal_metadata.stream_ordering,
-                    origin_server_ts=event.origin_server_ts,
-                ),
-            ),
-        )
-
-    def store_search_entries_txn(self, txn, entries):
-        """Add entries to the search table
-
-        Args:
-            txn (cursor):
-            entries (iterable[SearchEntry]):
-                entries to be added to the table
-        """
-        if not self.hs.config.enable_search:
-            return
-        if isinstance(self.database_engine, PostgresEngine):
-            sql = (
-                "INSERT INTO event_search"
-                " (event_id, room_id, key, vector, stream_ordering, origin_server_ts)"
-                " VALUES (?,?,?,to_tsvector('english', ?),?,?)"
+            yield self.db.updates._end_background_update(
+                self.EVENT_SEARCH_ORDER_UPDATE_NAME
             )
 
-            args = (
-                (
-                    entry.event_id,
-                    entry.room_id,
-                    entry.key,
-                    entry.value,
-                    entry.stream_ordering,
-                    entry.origin_server_ts,
-                )
-                for entry in entries
-            )
-
-            txn.executemany(sql, args)
+        return num_rows
 
-        elif isinstance(self.database_engine, Sqlite3Engine):
-            sql = (
-                "INSERT INTO event_search (event_id, room_id, key, value)"
-                " VALUES (?,?,?,?)"
-            )
-            args = (
-                (entry.event_id, entry.room_id, entry.key, entry.value)
-                for entry in entries
-            )
 
-            txn.executemany(sql, args)
-        else:
-            # This should be unreachable.
-            raise Exception("Unrecognized database engine")
+class SearchStore(SearchBackgroundUpdateStore):
+    def __init__(self, database: Database, db_conn, hs):
+        super(SearchStore, self).__init__(database, db_conn, hs)
 
     @defer.inlineCallbacks
     def search_msgs(self, room_ids, search_term, keys):
@@ -373,15 +364,17 @@ class SearchStore(BackgroundUpdateStore):
         """
         clauses = []
 
-        search_query = search_query = _parse_query(self.database_engine, search_term)
+        search_query = _parse_query(self.database_engine, search_term)
 
         args = []
 
         # Make sure we don't explode because the person is in too many rooms.
         # We filter the results below regardless.
         if len(room_ids) < 500:
-            clauses.append("room_id IN (%s)" % (",".join(["?"] * len(room_ids)),))
-            args.extend(room_ids)
+            clause, args = make_in_list_sql_clause(
+                self.database_engine, "room_id", room_ids
+            )
+            clauses = [clause]
 
         local_clauses = []
         for key in keys:
@@ -434,11 +427,18 @@ class SearchStore(BackgroundUpdateStore):
         # entire table from the database.
         sql += " ORDER BY rank DESC LIMIT 500"
 
-        results = yield self._execute("search_msgs", self.cursor_to_dict, sql, *args)
+        results = yield self.db.execute(
+            "search_msgs", self.db.cursor_to_dict, sql, *args
+        )
 
         results = list(filter(lambda row: row["room_id"] in room_ids, results))
 
-        events = yield self.get_events_as_list([r["event_id"] for r in results])
+        # We set redact_behaviour to BLOCK here to prevent redacted events being returned in
+        # search results (which is a data leak)
+        events = yield self.get_events_as_list(
+            [r["event_id"] for r in results],
+            redact_behaviour=EventRedactBehaviour.BLOCK,
+        )
 
         event_map = {ev.event_id: ev for ev in events}
 
@@ -448,8 +448,8 @@ class SearchStore(BackgroundUpdateStore):
 
         count_sql += " GROUP BY room_id"
 
-        count_results = yield self._execute(
-            "search_rooms_count", self.cursor_to_dict, count_sql, *count_args
+        count_results = yield self.db.execute(
+            "search_rooms_count", self.db.cursor_to_dict, count_sql, *count_args
         )
 
         count = sum(row["count"] for row in count_results if row["room_id"] in room_ids)
@@ -480,15 +480,17 @@ class SearchStore(BackgroundUpdateStore):
         """
         clauses = []
 
-        search_query = search_query = _parse_query(self.database_engine, search_term)
+        search_query = _parse_query(self.database_engine, search_term)
 
         args = []
 
         # Make sure we don't explode because the person is in too many rooms.
         # We filter the results below regardless.
         if len(room_ids) < 500:
-            clauses.append("room_id IN (%s)" % (",".join(["?"] * len(room_ids)),))
-            args.extend(room_ids)
+            clause, args = make_in_list_sql_clause(
+                self.database_engine, "room_id", room_ids
+            )
+            clauses = [clause]
 
         local_clauses = []
         for key in keys:
@@ -577,11 +579,18 @@ class SearchStore(BackgroundUpdateStore):
 
         args.append(limit)
 
-        results = yield self._execute("search_rooms", self.cursor_to_dict, sql, *args)
+        results = yield self.db.execute(
+            "search_rooms", self.db.cursor_to_dict, sql, *args
+        )
 
         results = list(filter(lambda row: row["room_id"] in room_ids, results))
 
-        events = yield self.get_events_as_list([r["event_id"] for r in results])
+        # We set redact_behaviour to BLOCK here to prevent redacted events being returned in
+        # search results (which is a data leak)
+        events = yield self.get_events_as_list(
+            [r["event_id"] for r in results],
+            redact_behaviour=EventRedactBehaviour.BLOCK,
+        )
 
         event_map = {ev.event_id: ev for ev in events}
 
@@ -591,8 +600,8 @@ class SearchStore(BackgroundUpdateStore):
 
         count_sql += " GROUP BY room_id"
 
-        count_results = yield self._execute(
-            "search_rooms_count", self.cursor_to_dict, count_sql, *count_args
+        count_results = yield self.db.execute(
+            "search_rooms_count", self.db.cursor_to_dict, count_sql, *count_args
         )
 
         count = sum(row["count"] for row in count_results if row["room_id"] in room_ids)
@@ -663,7 +672,7 @@ class SearchStore(BackgroundUpdateStore):
                     )
                 )
                 txn.execute(query, (value, search_query))
-                headline, = txn.fetchall()[0]
+                (headline,) = txn.fetchall()[0]
 
                 # Now we need to pick the possible highlights out of the headline
                 # result.
@@ -677,7 +686,7 @@ class SearchStore(BackgroundUpdateStore):
 
             return highlight_words
 
-        return self.runInteraction("_find_highlights", f)
+        return self.db.runInteraction("_find_highlights", f)
 
 
 def _to_postgres_options(options_dict):
diff --git a/synapse/storage/signatures.py b/synapse/storage/data_stores/main/signatures.py
index fb83218f90..36244d9f5d 100644
--- a/synapse/storage/signatures.py
+++ b/synapse/storage/data_stores/main/signatures.py
@@ -13,24 +13,13 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
-import six
-
 from unpaddedbase64 import encode_base64
 
 from twisted.internet import defer
 
-from synapse.crypto.event_signing import compute_event_reference_hash
+from synapse.storage._base import SQLBaseStore
 from synapse.util.caches.descriptors import cached, cachedList
 
-from ._base import SQLBaseStore
-
-# py2 sqlite has buffer hardcoded as only binary type, so we must use it,
-# despite being deprecated and removed in favor of memoryview
-if six.PY2:
-    db_binary_type = six.moves.builtins.buffer
-else:
-    db_binary_type = memoryview
-
 
 class SignatureWorkerStore(SQLBaseStore):
     @cached()
@@ -49,7 +38,7 @@ class SignatureWorkerStore(SQLBaseStore):
                 for event_id in event_ids
             }
 
-        return self.runInteraction("get_event_reference_hashes", f)
+        return self.db.runInteraction("get_event_reference_hashes", f)
 
     @defer.inlineCallbacks
     def add_event_hashes(self, event_ids):
@@ -80,23 +69,3 @@ class SignatureWorkerStore(SQLBaseStore):
 
 class SignatureStore(SignatureWorkerStore):
     """Persistence for event signatures and hashes"""
-
-    def _store_event_reference_hashes_txn(self, txn, events):
-        """Store a hash for a PDU
-        Args:
-            txn (cursor):
-            events (list): list of Events.
-        """
-
-        vals = []
-        for event in events:
-            ref_alg, ref_hash_bytes = compute_event_reference_hash(event)
-            vals.append(
-                {
-                    "event_id": event.event_id,
-                    "algorithm": ref_alg,
-                    "hash": db_binary_type(ref_hash_bytes),
-                }
-            )
-
-        self._simple_insert_many_txn(txn, table="event_reference_hashes", values=vals)
diff --git a/synapse/storage/data_stores/main/state.py b/synapse/storage/data_stores/main/state.py
new file mode 100644
index 0000000000..347cc50778
--- /dev/null
+++ b/synapse/storage/data_stores/main/state.py
@@ -0,0 +1,467 @@
+# -*- coding: utf-8 -*-
+# Copyright 2014-2016 OpenMarket Ltd
+# Copyright 2020 The Matrix.org Foundation C.I.C.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+import collections.abc
+import logging
+from collections import namedtuple
+
+from twisted.internet import defer
+
+from synapse.api.constants import EventTypes, Membership
+from synapse.api.errors import NotFoundError, UnsupportedRoomVersionError
+from synapse.api.room_versions import KNOWN_ROOM_VERSIONS, RoomVersion
+from synapse.storage._base import SQLBaseStore
+from synapse.storage.data_stores.main.events_worker import EventsWorkerStore
+from synapse.storage.data_stores.main.roommember import RoomMemberWorkerStore
+from synapse.storage.database import Database
+from synapse.storage.state import StateFilter
+from synapse.util.caches import intern_string
+from synapse.util.caches.descriptors import cached, cachedList
+
+logger = logging.getLogger(__name__)
+
+
+MAX_STATE_DELTA_HOPS = 100
+
+
+class _GetStateGroupDelta(
+    namedtuple("_GetStateGroupDelta", ("prev_group", "delta_ids"))
+):
+    """Return type of get_state_group_delta that implements __len__, which lets
+    us use the itrable flag when caching
+    """
+
+    __slots__ = []
+
+    def __len__(self):
+        return len(self.delta_ids) if self.delta_ids else 0
+
+
+# this inherits from EventsWorkerStore because it calls self.get_events
+class StateGroupWorkerStore(EventsWorkerStore, SQLBaseStore):
+    """The parts of StateGroupStore that can be called from workers.
+    """
+
+    def __init__(self, database: Database, db_conn, hs):
+        super(StateGroupWorkerStore, self).__init__(database, db_conn, hs)
+
+    async def get_room_version(self, room_id: str) -> RoomVersion:
+        """Get the room_version of a given room
+
+        Raises:
+            NotFoundError: if the room is unknown
+
+            UnsupportedRoomVersionError: if the room uses an unknown room version.
+                Typically this happens if support for the room's version has been
+                removed from Synapse.
+        """
+        room_version_id = await self.get_room_version_id(room_id)
+        v = KNOWN_ROOM_VERSIONS.get(room_version_id)
+
+        if not v:
+            raise UnsupportedRoomVersionError(
+                "Room %s uses a room version %s which is no longer supported"
+                % (room_id, room_version_id)
+            )
+
+        return v
+
+    @cached(max_entries=10000)
+    async def get_room_version_id(self, room_id: str) -> str:
+        """Get the room_version of a given room
+
+        Raises:
+            NotFoundError: if the room is unknown
+        """
+
+        # First we try looking up room version from the database, but for old
+        # rooms we might not have added the room version to it yet so we fall
+        # back to previous behaviour and look in current state events.
+
+        # We really should have an entry in the rooms table for every room we
+        # care about, but let's be a bit paranoid (at least while the background
+        # update is happening) to avoid breaking existing rooms.
+        version = await self.db.simple_select_one_onecol(
+            table="rooms",
+            keyvalues={"room_id": room_id},
+            retcol="room_version",
+            desc="get_room_version",
+            allow_none=True,
+        )
+
+        if version is not None:
+            return version
+
+        # Retrieve the room's create event
+        create_event = await self.get_create_event_for_room(room_id)
+        return create_event.content.get("room_version", "1")
+
+    @defer.inlineCallbacks
+    def get_room_predecessor(self, room_id):
+        """Get the predecessor of an upgraded room if it exists.
+        Otherwise return None.
+
+        Args:
+            room_id (str)
+
+        Returns:
+            Deferred[dict|None]: A dictionary containing the structure of the predecessor
+                field from the room's create event. The structure is controlled by other
+                servers, but it is expected to be:
+                    * room_id (str): The room ID of the predecessor room
+                    * event_id (str): The ID of the tombstone event in the predecessor room
+
+                None if a predecessor key is not found, or is not a dictionary.
+
+        Raises:
+            NotFoundError if the given room is unknown
+        """
+        # Retrieve the room's create event
+        create_event = yield self.get_create_event_for_room(room_id)
+
+        # Retrieve the predecessor key of the create event
+        predecessor = create_event.content.get("predecessor", None)
+
+        # Ensure the key is a dictionary
+        if not isinstance(predecessor, collections.abc.Mapping):
+            return None
+
+        return predecessor
+
+    @defer.inlineCallbacks
+    def get_create_event_for_room(self, room_id):
+        """Get the create state event for a room.
+
+        Args:
+            room_id (str)
+
+        Returns:
+            Deferred[EventBase]: The room creation event.
+
+        Raises:
+            NotFoundError if the room is unknown
+        """
+        state_ids = yield self.get_current_state_ids(room_id)
+        create_id = state_ids.get((EventTypes.Create, ""))
+
+        # If we can't find the create event, assume we've hit a dead end
+        if not create_id:
+            raise NotFoundError("Unknown room %s" % (room_id,))
+
+        # Retrieve the room's create event and return
+        create_event = yield self.get_event(create_id)
+        return create_event
+
+    @cached(max_entries=100000, iterable=True)
+    def get_current_state_ids(self, room_id):
+        """Get the current state event ids for a room based on the
+        current_state_events table.
+
+        Args:
+            room_id (str)
+
+        Returns:
+            deferred: dict of (type, state_key) -> event_id
+        """
+
+        def _get_current_state_ids_txn(txn):
+            txn.execute(
+                """SELECT type, state_key, event_id FROM current_state_events
+                WHERE room_id = ?
+                """,
+                (room_id,),
+            )
+
+            return {(intern_string(r[0]), intern_string(r[1])): r[2] for r in txn}
+
+        return self.db.runInteraction(
+            "get_current_state_ids", _get_current_state_ids_txn
+        )
+
+    # FIXME: how should this be cached?
+    def get_filtered_current_state_ids(
+        self, room_id: str, state_filter: StateFilter = StateFilter.all()
+    ):
+        """Get the current state event of a given type for a room based on the
+        current_state_events table.  This may not be as up-to-date as the result
+        of doing a fresh state resolution as per state_handler.get_current_state
+
+        Args:
+            room_id
+            state_filter: The state filter used to fetch state
+                from the database.
+
+        Returns:
+            defer.Deferred[StateMap[str]]: Map from type/state_key to event ID.
+        """
+
+        where_clause, where_args = state_filter.make_sql_filter_clause()
+
+        if not where_clause:
+            # We delegate to the cached version
+            return self.get_current_state_ids(room_id)
+
+        def _get_filtered_current_state_ids_txn(txn):
+            results = {}
+            sql = """
+                SELECT type, state_key, event_id FROM current_state_events
+                WHERE room_id = ?
+            """
+
+            if where_clause:
+                sql += " AND (%s)" % (where_clause,)
+
+            args = [room_id]
+            args.extend(where_args)
+            txn.execute(sql, args)
+            for row in txn:
+                typ, state_key, event_id = row
+                key = (intern_string(typ), intern_string(state_key))
+                results[key] = event_id
+
+            return results
+
+        return self.db.runInteraction(
+            "get_filtered_current_state_ids", _get_filtered_current_state_ids_txn
+        )
+
+    @defer.inlineCallbacks
+    def get_canonical_alias_for_room(self, room_id):
+        """Get canonical alias for room, if any
+
+        Args:
+            room_id (str)
+
+        Returns:
+            Deferred[str|None]: The canonical alias, if any
+        """
+
+        state = yield self.get_filtered_current_state_ids(
+            room_id, StateFilter.from_types([(EventTypes.CanonicalAlias, "")])
+        )
+
+        event_id = state.get((EventTypes.CanonicalAlias, ""))
+        if not event_id:
+            return
+
+        event = yield self.get_event(event_id, allow_none=True)
+        if not event:
+            return
+
+        return event.content.get("canonical_alias")
+
+    @cached(max_entries=50000)
+    def _get_state_group_for_event(self, event_id):
+        return self.db.simple_select_one_onecol(
+            table="event_to_state_groups",
+            keyvalues={"event_id": event_id},
+            retcol="state_group",
+            allow_none=True,
+            desc="_get_state_group_for_event",
+        )
+
+    @cachedList(
+        cached_method_name="_get_state_group_for_event",
+        list_name="event_ids",
+        num_args=1,
+        inlineCallbacks=True,
+    )
+    def _get_state_group_for_events(self, event_ids):
+        """Returns mapping event_id -> state_group
+        """
+        rows = yield self.db.simple_select_many_batch(
+            table="event_to_state_groups",
+            column="event_id",
+            iterable=event_ids,
+            keyvalues={},
+            retcols=("event_id", "state_group"),
+            desc="_get_state_group_for_events",
+        )
+
+        return {row["event_id"]: row["state_group"] for row in rows}
+
+    @defer.inlineCallbacks
+    def get_referenced_state_groups(self, state_groups):
+        """Check if the state groups are referenced by events.
+
+        Args:
+            state_groups (Iterable[int])
+
+        Returns:
+            Deferred[set[int]]: The subset of state groups that are
+            referenced.
+        """
+
+        rows = yield self.db.simple_select_many_batch(
+            table="event_to_state_groups",
+            column="state_group",
+            iterable=state_groups,
+            keyvalues={},
+            retcols=("DISTINCT state_group",),
+            desc="get_referenced_state_groups",
+        )
+
+        return {row["state_group"] for row in rows}
+
+
+class MainStateBackgroundUpdateStore(RoomMemberWorkerStore):
+
+    CURRENT_STATE_INDEX_UPDATE_NAME = "current_state_members_idx"
+    EVENT_STATE_GROUP_INDEX_UPDATE_NAME = "event_to_state_groups_sg_index"
+    DELETE_CURRENT_STATE_UPDATE_NAME = "delete_old_current_state_events"
+
+    def __init__(self, database: Database, db_conn, hs):
+        super(MainStateBackgroundUpdateStore, self).__init__(database, db_conn, hs)
+
+        self.server_name = hs.hostname
+
+        self.db.updates.register_background_index_update(
+            self.CURRENT_STATE_INDEX_UPDATE_NAME,
+            index_name="current_state_events_member_index",
+            table="current_state_events",
+            columns=["state_key"],
+            where_clause="type='m.room.member'",
+        )
+        self.db.updates.register_background_index_update(
+            self.EVENT_STATE_GROUP_INDEX_UPDATE_NAME,
+            index_name="event_to_state_groups_sg_index",
+            table="event_to_state_groups",
+            columns=["state_group"],
+        )
+        self.db.updates.register_background_update_handler(
+            self.DELETE_CURRENT_STATE_UPDATE_NAME, self._background_remove_left_rooms,
+        )
+
+    async def _background_remove_left_rooms(self, progress, batch_size):
+        """Background update to delete rows from `current_state_events` and
+        `event_forward_extremities` tables of rooms that the server is no
+        longer joined to.
+        """
+
+        last_room_id = progress.get("last_room_id", "")
+
+        def _background_remove_left_rooms_txn(txn):
+            sql = """
+                SELECT DISTINCT room_id FROM current_state_events
+                WHERE room_id > ? ORDER BY room_id LIMIT ?
+            """
+
+            txn.execute(sql, (last_room_id, batch_size))
+            room_ids = [row[0] for row in txn]
+            if not room_ids:
+                return True, set()
+
+            sql = """
+                SELECT room_id
+                FROM current_state_events
+                WHERE
+                    room_id > ? AND room_id <= ?
+                    AND type = 'm.room.member'
+                    AND membership = 'join'
+                    AND state_key LIKE ?
+                GROUP BY room_id
+            """
+
+            txn.execute(sql, (last_room_id, room_ids[-1], "%:" + self.server_name))
+
+            joined_room_ids = {row[0] for row in txn}
+
+            left_rooms = set(room_ids) - joined_room_ids
+
+            logger.info("Deleting current state left rooms: %r", left_rooms)
+
+            # First we get all users that we still think were joined to the
+            # room. This is so that we can mark those device lists as
+            # potentially stale, since there may have been a period where the
+            # server didn't share a room with the remote user and therefore may
+            # have missed any device updates.
+            rows = self.db.simple_select_many_txn(
+                txn,
+                table="current_state_events",
+                column="room_id",
+                iterable=left_rooms,
+                keyvalues={"type": EventTypes.Member, "membership": Membership.JOIN},
+                retcols=("state_key",),
+            )
+
+            potentially_left_users = {row["state_key"] for row in rows}
+
+            # Now lets actually delete the rooms from the DB.
+            self.db.simple_delete_many_txn(
+                txn,
+                table="current_state_events",
+                column="room_id",
+                iterable=left_rooms,
+                keyvalues={},
+            )
+
+            self.db.simple_delete_many_txn(
+                txn,
+                table="event_forward_extremities",
+                column="room_id",
+                iterable=left_rooms,
+                keyvalues={},
+            )
+
+            self.db.updates._background_update_progress_txn(
+                txn,
+                self.DELETE_CURRENT_STATE_UPDATE_NAME,
+                {"last_room_id": room_ids[-1]},
+            )
+
+            return False, potentially_left_users
+
+        finished, potentially_left_users = await self.db.runInteraction(
+            "_background_remove_left_rooms", _background_remove_left_rooms_txn
+        )
+
+        if finished:
+            await self.db.updates._end_background_update(
+                self.DELETE_CURRENT_STATE_UPDATE_NAME
+            )
+
+        # Now go and check if we still share a room with the remote users in
+        # the deleted rooms. If not mark their device lists as stale.
+        joined_users = await self.get_users_server_still_shares_room_with(
+            potentially_left_users
+        )
+
+        for user_id in potentially_left_users - joined_users:
+            await self.mark_remote_user_device_list_as_unsubscribed(user_id)
+
+        return batch_size
+
+
+class StateStore(StateGroupWorkerStore, MainStateBackgroundUpdateStore):
+    """ Keeps track of the state at a given event.
+
+    This is done via the concept of `state groups`. Every event is assigned
+    a state group (identified by an arbitrary integer ID), which references a
+    collection of state events. The current state of an event is then the
+    collection of state events referenced by the event's state group.
+
+    Hence, every change in the current state causes a new state group to be
+    generated. However, if no change happens (e.g. a message event with only
+    one parent), the event inherits the state group from its parent.
+
+    There are three tables:
+      * `state_groups`: Stores the group ID, the room ID and the first event
+        in the group.
+      * `event_to_state_groups`: Maps events to state groups.
+      * `state_groups_state`: Maps state group to state events.
+    """
+
+    def __init__(self, database: Database, db_conn, hs):
+        super(StateStore, self).__init__(database, db_conn, hs)
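To make the docstring above concrete, here is a deliberately simplified sketch of how an event's full state could be resolved from these tables, including the delta chaining through `state_group_edges`. The real implementation batches, caches and bounds this walk, so treat this purely as an illustration; the `?`-style placeholders follow the SQLite convention.

    def get_state_for_event(txn, event_id):
        """Resolve the full (type, state_key) -> event_id map for one event."""
        txn.execute(
            "SELECT state_group FROM event_to_state_groups WHERE event_id = ?",
            (event_id,),
        )
        row = txn.fetchone()
        if row is None:
            return {}
        group = row[0]

        # Walk the delta chain from the event's group back to its ancestors.
        chain = []
        while group is not None:
            chain.append(group)
            txn.execute(
                "SELECT prev_state_group FROM state_group_edges WHERE state_group = ?",
                (group,),
            )
            row = txn.fetchone()
            group = row[0] if row else None

        # Apply deltas oldest-first, so groups nearer the event override ancestors.
        state = {}
        for g in reversed(chain):
            txn.execute(
                "SELECT type, state_key, event_id FROM state_groups_state "
                "WHERE state_group = ?",
                (g,),
            )
            for typ, state_key, ev_id in txn.fetchall():
                state[(typ, state_key)] = ev_id
        return state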
diff --git a/synapse/storage/state_deltas.py b/synapse/storage/data_stores/main/state_deltas.py
index 5fdb442104..725e12507f 100644
--- a/synapse/storage/state_deltas.py
+++ b/synapse/storage/data_stores/main/state_deltas.py
@@ -15,13 +15,15 @@
 
 import logging
 
+from twisted.internet import defer
+
 from synapse.storage._base import SQLBaseStore
 
 logger = logging.getLogger(__name__)
 
 
 class StateDeltasStore(SQLBaseStore):
-    def get_current_state_deltas(self, prev_stream_id):
+    def get_current_state_deltas(self, prev_stream_id: int, max_stream_id: int):
         """Fetch a list of room state changes since the given stream id
 
         Each entry in the result contains the following fields:
@@ -36,15 +38,27 @@ class StateDeltasStore(SQLBaseStore):
 
         Args:
             prev_stream_id (int): point to get changes since (exclusive)
+            max_stream_id (int): the point that we know has been correctly persisted
+               - ie, an upper limit to return changes from.
 
         Returns:
-            Deferred[list[dict]]: results
+            Deferred[tuple[int, list[dict]]]: A tuple consisting of:
+               - the stream id which these results go up to
+               - list of current_state_delta_stream rows. If it is empty, we are
+                 up to date.
         """
         prev_stream_id = int(prev_stream_id)
+
+        # check we're not going backwards
+        assert prev_stream_id <= max_stream_id
+
         if not self._curr_state_delta_stream_cache.has_any_entity_changed(
             prev_stream_id
         ):
-            return []
+            # if the CSDs haven't changed between prev_stream_id and now, we
+            # know for certain that they haven't changed between prev_stream_id and
+            # max_stream_id.
+            return defer.succeed((max_stream_id, []))
 
         def get_current_state_deltas_txn(txn):
             # First we calculate the max stream id that will give us less than
@@ -54,21 +68,29 @@ class StateDeltasStore(SQLBaseStore):
             sql = """
                 SELECT stream_id, count(*)
                 FROM current_state_delta_stream
-                WHERE stream_id > ?
+                WHERE stream_id > ? AND stream_id <= ?
                 GROUP BY stream_id
                 ORDER BY stream_id ASC
                 LIMIT 100
             """
-            txn.execute(sql, (prev_stream_id,))
+            txn.execute(sql, (prev_stream_id, max_stream_id))
 
             total = 0
-            max_stream_id = prev_stream_id
-            for max_stream_id, count in txn:
+
+            for stream_id, count in txn:
                 total += count
                 if total > 100:
                     # We arbitrarily limit to 100 entries to ensure we don't
                     # select too many.
+                    logger.debug(
+                        "Clipping current_state_delta_stream rows to stream_id %i",
+                        stream_id,
+                    )
+                    clipped_stream_id = stream_id
                     break
+            else:
+                # if we didn't hit the limit, we may as well go right up to the max_stream_id
+                clipped_stream_id = max_stream_id
 
             # Now actually get the deltas
             sql = """
@@ -77,15 +99,15 @@ class StateDeltasStore(SQLBaseStore):
                 WHERE ? < stream_id AND stream_id <= ?
                 ORDER BY stream_id ASC
             """
-            txn.execute(sql, (prev_stream_id, max_stream_id))
-            return self.cursor_to_dict(txn)
+            txn.execute(sql, (prev_stream_id, clipped_stream_id))
+            return clipped_stream_id, self.db.cursor_to_dict(txn)
 
-        return self.runInteraction(
+        return self.db.runInteraction(
             "get_current_state_deltas", get_current_state_deltas_txn
         )
 
     def _get_max_stream_id_in_current_state_deltas_txn(self, txn):
-        return self._simple_select_one_onecol_txn(
+        return self.db.simple_select_one_onecol_txn(
             txn,
             table="current_state_delta_stream",
             keyvalues={},
@@ -93,7 +115,7 @@ class StateDeltasStore(SQLBaseStore):
         )
 
     def get_max_stream_id_in_current_state_deltas(self):
-        return self.runInteraction(
+        return self.db.runInteraction(
             "get_max_stream_id_in_current_state_deltas",
             self._get_max_stream_id_in_current_state_deltas_txn,
         )
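The combination of the new max_stream_id argument and the (clipped stream id, rows) return value is what lets a caller page through the delta stream without skipping rows when the 100-row clip kicks in. A rough consumer loop, assuming an awaitable store interface and passing in a delta handler, could look like this:

    # Rough paging loop over get_current_state_deltas; `store` and
    # `max_persisted_stream_id()` are hypothetical stand-ins, not real
    # Synapse names.
    async def catch_up(store, prev_stream_id: int, handle_delta) -> int:
        while True:
            max_stream_id = store.max_persisted_stream_id()
            clipped_id, deltas = await store.get_current_state_deltas(
                prev_stream_id, max_stream_id
            )
            for delta in deltas:
                handle_delta(delta)  # one current_state_delta_stream row
            if not deltas:
                # per the docstring, an empty list means we are up to date
                return clipped_id
            # the result may have been clipped below max_stream_id: resume there
            prev_stream_id = clipped_id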
diff --git a/synapse/storage/stats.py b/synapse/storage/data_stores/main/stats.py
index 09190d684e..380c1ec7da 100644
--- a/synapse/storage/stats.py
+++ b/synapse/storage/data_stores/main/stats.py
@@ -21,8 +21,9 @@ from twisted.internet import defer
 from twisted.internet.defer import DeferredLock
 
 from synapse.api.constants import EventTypes, Membership
-from synapse.storage import PostgresEngine
-from synapse.storage.state_deltas import StateDeltasStore
+from synapse.storage.data_stores.main.state_deltas import StateDeltasStore
+from synapse.storage.database import Database
+from synapse.storage.engines import PostgresEngine
 from synapse.util.caches.descriptors import cached
 
 logger = logging.getLogger(__name__)
@@ -58,8 +59,8 @@ TYPE_TO_ORIGIN_TABLE = {"room": ("rooms", "room_id"), "user": ("users", "name")}
 
 
 class StatsStore(StateDeltasStore):
-    def __init__(self, db_conn, hs):
-        super(StatsStore, self).__init__(db_conn, hs)
+    def __init__(self, database: Database, db_conn, hs):
+        super(StatsStore, self).__init__(database, db_conn, hs)
 
         self.server_name = hs.hostname
         self.clock = self.hs.get_clock()
@@ -68,17 +69,17 @@ class StatsStore(StateDeltasStore):
 
         self.stats_delta_processing_lock = DeferredLock()
 
-        self.register_background_update_handler(
+        self.db.updates.register_background_update_handler(
             "populate_stats_process_rooms", self._populate_stats_process_rooms
         )
-        self.register_background_update_handler(
+        self.db.updates.register_background_update_handler(
             "populate_stats_process_users", self._populate_stats_process_users
         )
         # we no longer need to perform clean-up, but we will give ourselves
         # the potential to reintroduce it in the future – so documentation
         # will still encourage the use of this no-op handler.
-        self.register_noop_background_update("populate_stats_cleanup")
-        self.register_noop_background_update("populate_stats_prepare")
+        self.db.updates.register_noop_background_update("populate_stats_cleanup")
+        self.db.updates.register_noop_background_update("populate_stats_prepare")
 
     def quantise_stats_time(self, ts):
         """
@@ -102,7 +103,7 @@ class StatsStore(StateDeltasStore):
         This is a background update which regenerates statistics for users.
         """
         if not self.stats_enabled:
-            yield self._end_background_update("populate_stats_process_users")
+            yield self.db.updates._end_background_update("populate_stats_process_users")
             return 1
 
         last_user_id = progress.get("last_user_id", "")
@@ -117,22 +118,22 @@ class StatsStore(StateDeltasStore):
             txn.execute(sql, (last_user_id, batch_size))
             return [r for r, in txn]
 
-        users_to_work_on = yield self.runInteraction(
+        users_to_work_on = yield self.db.runInteraction(
             "_populate_stats_process_users", _get_next_batch
         )
 
         # No more rooms -- complete the transaction.
         if not users_to_work_on:
-            yield self._end_background_update("populate_stats_process_users")
+            yield self.db.updates._end_background_update("populate_stats_process_users")
             return 1
 
         for user_id in users_to_work_on:
             yield self._calculate_and_set_initial_state_for_user(user_id)
             progress["last_user_id"] = user_id
 
-        yield self.runInteraction(
+        yield self.db.runInteraction(
             "populate_stats_process_users",
-            self._background_update_progress_txn,
+            self.db.updates._background_update_progress_txn,
             "populate_stats_process_users",
             progress,
         )
@@ -145,7 +146,7 @@ class StatsStore(StateDeltasStore):
         This is a background update which regenerates statistics for rooms.
         """
         if not self.stats_enabled:
-            yield self._end_background_update("populate_stats_process_rooms")
+            yield self.db.updates._end_background_update("populate_stats_process_rooms")
             return 1
 
         last_room_id = progress.get("last_room_id", "")
@@ -160,22 +161,22 @@ class StatsStore(StateDeltasStore):
             txn.execute(sql, (last_room_id, batch_size))
             return [r for r, in txn]
 
-        rooms_to_work_on = yield self.runInteraction(
+        rooms_to_work_on = yield self.db.runInteraction(
             "populate_stats_rooms_get_batch", _get_next_batch
         )
 
         # No more rooms -- complete the transaction.
         if not rooms_to_work_on:
-            yield self._end_background_update("populate_stats_process_rooms")
+            yield self.db.updates._end_background_update("populate_stats_process_rooms")
             return 1
 
         for room_id in rooms_to_work_on:
             yield self._calculate_and_set_initial_state_for_room(room_id)
             progress["last_room_id"] = room_id
 
-        yield self.runInteraction(
+        yield self.db.runInteraction(
             "_populate_stats_process_rooms",
-            self._background_update_progress_txn,
+            self.db.updates._background_update_progress_txn,
             "populate_stats_process_rooms",
             progress,
         )
@@ -186,7 +187,7 @@ class StatsStore(StateDeltasStore):
         """
         Returns the stats processor positions.
         """
-        return self._simple_select_one_onecol(
+        return self.db.simple_select_one_onecol(
             table="stats_incremental_position",
             keyvalues={},
             retcol="stream_id",
@@ -215,7 +216,7 @@ class StatsStore(StateDeltasStore):
             if field and "\0" in field:
                 fields[col] = None
 
-        return self._simple_upsert(
+        return self.db.simple_upsert(
             table="room_stats_state",
             keyvalues={"room_id": room_id},
             values=fields,
@@ -236,7 +237,7 @@ class StatsStore(StateDeltasStore):
             Deferred[list[dict]], where the dict has the keys of
             ABSOLUTE_STATS_FIELDS[stats_type],  and "bucket_size" and "end_ts".
         """
-        return self.runInteraction(
+        return self.db.runInteraction(
             "get_statistics_for_subject",
             self._get_statistics_for_subject_txn,
             stats_type,
@@ -257,44 +258,19 @@ class StatsStore(StateDeltasStore):
             ABSOLUTE_STATS_FIELDS[stats_type] + PER_SLICE_FIELDS[stats_type]
         )
 
-        slice_list = self._simple_select_list_paginate_txn(
+        slice_list = self.db.simple_select_list_paginate_txn(
             txn,
             table + "_historical",
-            {id_col: stats_id},
             "end_ts",
             start,
             size,
             retcols=selected_columns + ["bucket_size", "end_ts"],
+            keyvalues={id_col: stats_id},
             order_direction="DESC",
         )
 
         return slice_list
 
-    def get_room_stats_state(self, room_id):
-        """
-        Returns the current room_stats_state for a room.
-
-        Args:
-            room_id (str): The ID of the room to return state for.
-
-        Returns (dict):
-            Dictionary containing these keys:
-                "name", "topic", "canonical_alias", "avatar", "join_rules",
-                "history_visibility"
-        """
-        return self._simple_select_one(
-            "room_stats_state",
-            {"room_id": room_id},
-            retcols=(
-                "name",
-                "topic",
-                "canonical_alias",
-                "avatar",
-                "join_rules",
-                "history_visibility",
-            ),
-        )
-
     @cached()
     def get_earliest_token_for_stats(self, stats_type, id):
         """
@@ -308,7 +284,7 @@ class StatsStore(StateDeltasStore):
         """
         table, id_col = TYPE_TO_TABLE[stats_type]
 
-        return self._simple_select_one_onecol(
+        return self.db.simple_select_one_onecol(
             "%s_current" % (table,),
             keyvalues={id_col: id},
             retcol="completed_delta_stream_id",
@@ -332,6 +308,9 @@ class StatsStore(StateDeltasStore):
         def _bulk_update_stats_delta_txn(txn):
             for stats_type, stats_updates in updates.items():
                 for stats_id, fields in stats_updates.items():
+                    logger.debug(
+                        "Updating %s stats for %s: %s", stats_type, stats_id, fields
+                    )
                     self._update_stats_delta_txn(
                         txn,
                         ts=ts,
@@ -341,14 +320,14 @@ class StatsStore(StateDeltasStore):
                         complete_with_stream_id=stream_id,
                     )
 
-            self._simple_update_one_txn(
+            self.db.simple_update_one_txn(
                 txn,
                 table="stats_incremental_position",
                 keyvalues={},
                 updatevalues={"stream_id": stream_id},
             )
 
-        return self.runInteraction(
+        return self.db.runInteraction(
             "bulk_update_stats_delta", _bulk_update_stats_delta_txn
         )
 
@@ -379,7 +358,7 @@ class StatsStore(StateDeltasStore):
                 Does not work with per-slice fields.
         """
 
-        return self.runInteraction(
+        return self.db.runInteraction(
             "update_stats_delta",
             self._update_stats_delta_txn,
             ts,
@@ -514,17 +493,17 @@ class StatsStore(StateDeltasStore):
         else:
             self.database_engine.lock_table(txn, table)
             retcols = list(chain(absolutes.keys(), additive_relatives.keys()))
-            current_row = self._simple_select_one_txn(
+            current_row = self.db.simple_select_one_txn(
                 txn, table, keyvalues, retcols, allow_none=True
             )
             if current_row is None:
                 merged_dict = {**keyvalues, **absolutes, **additive_relatives}
-                self._simple_insert_txn(txn, table, merged_dict)
+                self.db.simple_insert_txn(txn, table, merged_dict)
             else:
                 for (key, val) in additive_relatives.items():
                     current_row[key] += val
                 current_row.update(absolutes)
-                self._simple_update_one_txn(txn, table, keyvalues, current_row)
+                self.db.simple_update_one_txn(txn, table, keyvalues, current_row)
 
     def _upsert_copy_from_table_with_additive_relatives_txn(
         self,
@@ -611,11 +590,11 @@ class StatsStore(StateDeltasStore):
             txn.execute(sql, qargs)
         else:
             self.database_engine.lock_table(txn, into_table)
-            src_row = self._simple_select_one_txn(
+            src_row = self.db.simple_select_one_txn(
                 txn, src_table, keyvalues, copy_columns
             )
             all_dest_keyvalues = {**keyvalues, **extra_dst_keyvalues}
-            dest_current_row = self._simple_select_one_txn(
+            dest_current_row = self.db.simple_select_one_txn(
                 txn,
                 into_table,
                 keyvalues=all_dest_keyvalues,
@@ -631,11 +610,11 @@ class StatsStore(StateDeltasStore):
                     **src_row,
                     **additive_relatives,
                 }
-                self._simple_insert_txn(txn, into_table, merged_dict)
+                self.db.simple_insert_txn(txn, into_table, merged_dict)
             else:
                 for (key, val) in additive_relatives.items():
                     src_row[key] = dest_current_row[key] + val
-                self._simple_update_txn(txn, into_table, all_dest_keyvalues, src_row)
+                self.db.simple_update_txn(txn, into_table, all_dest_keyvalues, src_row)
 
     def get_changes_room_total_events_and_bytes(self, min_pos, max_pos):
         """Fetches the counts of events in the given range of stream IDs.
@@ -649,7 +628,7 @@ class StatsStore(StateDeltasStore):
             changes.
         """
 
-        return self.runInteraction(
+        return self.db.runInteraction(
             "stats_incremental_total_events_and_bytes",
             self.get_changes_room_total_events_and_bytes_txn,
             min_pos,
@@ -732,7 +711,7 @@ class StatsStore(StateDeltasStore):
         def _fetch_current_state_stats(txn):
             pos = self.get_room_max_stream_ordering()
 
-            rows = self._simple_select_many_txn(
+            rows = self.db.simple_select_many_txn(
                 txn,
                 table="current_state_events",
                 column="type",
@@ -740,7 +719,7 @@ class StatsStore(StateDeltasStore):
                     EventTypes.Create,
                     EventTypes.JoinRules,
                     EventTypes.RoomHistoryVisibility,
-                    EventTypes.Encryption,
+                    EventTypes.RoomEncryption,
                     EventTypes.Name,
                     EventTypes.Topic,
                     EventTypes.RoomAvatar,
@@ -770,7 +749,7 @@ class StatsStore(StateDeltasStore):
                 (room_id,),
             )
 
-            current_state_events_count, = txn.fetchone()
+            (current_state_events_count,) = txn.fetchone()
 
             users_in_room = self.get_users_in_room_txn(txn, room_id)
 
@@ -788,7 +767,7 @@ class StatsStore(StateDeltasStore):
             current_state_events_count,
             users_in_room,
             pos,
-        ) = yield self.runInteraction(
+        ) = yield self.db.runInteraction(
             "get_initial_state_for_room", _fetch_current_state_stats
         )
 
@@ -812,7 +791,7 @@ class StatsStore(StateDeltasStore):
                 room_state["history_visibility"] = event.content.get(
                     "history_visibility"
                 )
-            elif event.type == EventTypes.Encryption:
+            elif event.type == EventTypes.RoomEncryption:
                 room_state["encryption"] = event.content.get("algorithm")
             elif event.type == EventTypes.Name:
                 room_state["name"] = event.content.get("name")
@@ -860,10 +839,10 @@ class StatsStore(StateDeltasStore):
                 """,
                 (user_id,),
             )
-            count, = txn.fetchone()
+            (count,) = txn.fetchone()
             return count, pos
 
-        joined_rooms, pos = yield self.runInteraction(
+        joined_rooms, pos = yield self.db.runInteraction(
             "calculate_and_set_initial_state_for_user",
             _calculate_and_set_initial_state_for_user_txn,
         )
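The fallback path visible above (select the current row; insert a merged dict if none exists; otherwise add the additive relatives onto the current values and overwrite the absolutes) can be summarised as a small standalone function. This is a simplified sketch of the pattern, not the transaction code itself:

    # Simplified version of the absolute/additive-relative upsert pattern used
    # by the stats store above, operating on plain dicts instead of SQL rows.
    def upsert_with_additive_relatives(current_row, keyvalues, absolutes, additive_relatives):
        if current_row is None:
            # No existing row: relatives start from zero, so a plain merge suffices.
            return {**keyvalues, **absolutes, **additive_relatives}
        for key, val in additive_relatives.items():
            current_row[key] = current_row.get(key, 0) + val
        current_row.update(absolutes)
        return current_row

    row = upsert_with_additive_relatives(
        current_row={"room_id": "!r:example.org", "joined_members": 3, "topic": "old"},
        keyvalues={"room_id": "!r:example.org"},
        absolutes={"topic": "new"},
        additive_relatives={"joined_members": 2},
    )
    assert row == {"room_id": "!r:example.org", "joined_members": 5, "topic": "new"}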
diff --git a/synapse/storage/stream.py b/synapse/storage/data_stores/main/stream.py
index 490454f19a..e89f0bffb5 100644
--- a/synapse/storage/stream.py
+++ b/synapse/storage/data_stores/main/stream.py
@@ -1,5 +1,8 @@
 # -*- coding: utf-8 -*-
 # Copyright 2014-2016 OpenMarket Ltd
+# Copyright 2017 Vector Creations Ltd
+# Copyright 2018-2019 New Vector Ltd
+# Copyright 2019 The Matrix.org Foundation C.I.C.
 #
 # Licensed under the Apache License, Version 2.0 (the "License");
 # you may not use this file except in compliance with the License.
@@ -43,8 +46,9 @@ from twisted.internet import defer
 
 from synapse.logging.context import make_deferred_yieldable, run_in_background
 from synapse.storage._base import SQLBaseStore
+from synapse.storage.data_stores.main.events_worker import EventsWorkerStore
+from synapse.storage.database import Database
 from synapse.storage.engines import PostgresEngine
-from synapse.storage.events_worker import EventsWorkerStore
 from synapse.types import RoomStreamToken
 from synapse.util.caches.stream_change_cache import StreamChangeCache
 
@@ -229,6 +233,14 @@ def filter_to_clause(event_filter):
         clauses.append("contains_url = ?")
         args.append(event_filter.contains_url)
 
+    # We're only applying the "labels" filter on the database query, because applying the
+    # "not_labels" filter via a SQL query is non-trivial. Instead, we let
+    # event_filter.check_fields apply it, which is not as efficient but makes the
+    # implementation simpler.
+    if event_filter.labels:
+        clauses.append("(%s)" % " OR ".join("label = ?" for _ in event_filter.labels))
+        args.extend(event_filter.labels)
+
     return " AND ".join(clauses), args
 
 
@@ -240,11 +252,11 @@ class StreamWorkerStore(EventsWorkerStore, SQLBaseStore):
 
     __metaclass__ = abc.ABCMeta
 
-    def __init__(self, db_conn, hs):
-        super(StreamWorkerStore, self).__init__(db_conn, hs)
+    def __init__(self, database: Database, db_conn, hs):
+        super(StreamWorkerStore, self).__init__(database, db_conn, hs)
 
         events_max = self.get_room_max_stream_ordering()
-        event_cache_prefill, min_event_val = self._get_cache_dict(
+        event_cache_prefill, min_event_val = self.db.get_cache_dict(
             db_conn,
             "events",
             entity_column="room_id",
@@ -334,11 +346,11 @@ class StreamWorkerStore(EventsWorkerStore, SQLBaseStore):
             from_key (str): The room_key portion of a StreamToken
         """
         from_key = RoomStreamToken.parse_stream_token(from_key).stream
-        return set(
+        return {
             room_id
             for room_id in room_ids
             if self._events_stream_cache.has_entity_changed(room_id, from_key)
-        )
+        }
 
     @defer.inlineCallbacks
     def get_room_events_stream_for_room(
@@ -389,7 +401,7 @@ class StreamWorkerStore(EventsWorkerStore, SQLBaseStore):
             rows = [_EventDictReturn(row[0], None, row[1]) for row in txn]
             return rows
 
-        rows = yield self.runInteraction("get_room_events_stream_for_room", f)
+        rows = yield self.db.runInteraction("get_room_events_stream_for_room", f)
 
         ret = yield self.get_events_as_list(
             [r.event_id for r in rows], get_prev_content=True
@@ -439,7 +451,7 @@ class StreamWorkerStore(EventsWorkerStore, SQLBaseStore):
 
             return rows
 
-        rows = yield self.runInteraction("get_membership_changes_for_user", f)
+        rows = yield self.db.runInteraction("get_membership_changes_for_user", f)
 
         ret = yield self.get_events_as_list(
             [r.event_id for r in rows], get_prev_content=True
@@ -469,11 +481,9 @@ class StreamWorkerStore(EventsWorkerStore, SQLBaseStore):
             room_id, limit, end_token
         )
 
-        logger.debug("stream before")
         events = yield self.get_events_as_list(
             [r.event_id for r in rows], get_prev_content=True
         )
-        logger.debug("stream after")
 
         self._set_before_and_after(events, rows)
 
@@ -500,7 +510,7 @@ class StreamWorkerStore(EventsWorkerStore, SQLBaseStore):
 
         end_token = RoomStreamToken.parse(end_token)
 
-        rows, token = yield self.runInteraction(
+        rows, token = yield self.db.runInteraction(
             "get_recent_event_ids_for_room",
             self._paginate_room_events_txn,
             room_id,
@@ -513,8 +523,8 @@ class StreamWorkerStore(EventsWorkerStore, SQLBaseStore):
 
         return rows, token
 
-    def get_room_event_after_stream_ordering(self, room_id, stream_ordering):
-        """Gets details of the first event in a room at or after a stream ordering
+    def get_room_event_before_stream_ordering(self, room_id, stream_ordering):
+        """Gets details of the most recent event in a room at or before a stream ordering
 
         Args:
             room_id (str):
@@ -529,15 +539,15 @@ class StreamWorkerStore(EventsWorkerStore, SQLBaseStore):
             sql = (
                 "SELECT stream_ordering, topological_ordering, event_id"
                 " FROM events"
-                " WHERE room_id = ? AND stream_ordering >= ?"
+                " WHERE room_id = ? AND stream_ordering <= ?"
                 " AND NOT outlier"
-                " ORDER BY stream_ordering"
+                " ORDER BY stream_ordering DESC"
                 " LIMIT 1"
             )
             txn.execute(sql, (room_id, stream_ordering))
             return txn.fetchone()
 
-        return self.runInteraction("get_room_event_after_stream_ordering", _f)
+        return self.db.runInteraction("get_room_event_before_stream_ordering", _f)
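A hypothetical caller of the renamed helper, looking up the ID of the newest event at or before a given stream position (only the store method name comes from the code above; the wrapper is illustrative):

    from twisted.internet import defer

    # Illustrative wrapper around get_room_event_before_stream_ordering.
    @defer.inlineCallbacks
    def event_id_at_or_before(store, room_id, stream_ordering):
        # row is (stream_ordering, topological_ordering, event_id), or None if
        # the room has no event at or before that position.
        row = yield store.get_room_event_before_stream_ordering(room_id, stream_ordering)
        return row[2] if row else None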
 
     @defer.inlineCallbacks
     def get_room_events_max_id(self, room_id=None):
@@ -551,7 +561,7 @@ class StreamWorkerStore(EventsWorkerStore, SQLBaseStore):
         if room_id is None:
             return "s%d" % (token,)
         else:
-            topo = yield self.runInteraction(
+            topo = yield self.db.runInteraction(
                 "_get_max_topological_txn", self._get_max_topological_txn, room_id
             )
             return "t%d-%d" % (topo, token)
@@ -565,7 +575,7 @@ class StreamWorkerStore(EventsWorkerStore, SQLBaseStore):
         Returns:
             A deferred "s%d" stream token.
         """
-        return self._simple_select_one_onecol(
+        return self.db.simple_select_one_onecol(
             table="events", keyvalues={"event_id": event_id}, retcol="stream_ordering"
         ).addCallback(lambda row: "s%d" % (row,))
 
@@ -578,7 +588,7 @@ class StreamWorkerStore(EventsWorkerStore, SQLBaseStore):
         Returns:
             A deferred "t%d-%d" topological token.
         """
-        return self._simple_select_one(
+        return self.db.simple_select_one(
             table="events",
             keyvalues={"event_id": event_id},
             retcols=("stream_ordering", "topological_ordering"),
@@ -602,13 +612,13 @@ class StreamWorkerStore(EventsWorkerStore, SQLBaseStore):
             "SELECT coalesce(max(topological_ordering), 0) FROM events"
             " WHERE room_id = ? AND stream_ordering < ?"
         )
-        return self._execute(
+        return self.db.execute(
             "get_max_topological_token", None, sql, room_id, stream_key
         ).addCallback(lambda r: r[0][0] if r else 0)
 
     def _get_max_topological_txn(self, txn, room_id):
         txn.execute(
-            "SELECT MAX(topological_ordering) FROM events" " WHERE room_id = ?",
+            "SELECT MAX(topological_ordering) FROM events WHERE room_id = ?",
             (room_id,),
         )
 
@@ -656,7 +666,7 @@ class StreamWorkerStore(EventsWorkerStore, SQLBaseStore):
             dict
         """
 
-        results = yield self.runInteraction(
+        results = yield self.db.runInteraction(
             "get_events_around",
             self._get_events_around_txn,
             room_id,
@@ -667,11 +677,11 @@ class StreamWorkerStore(EventsWorkerStore, SQLBaseStore):
         )
 
         events_before = yield self.get_events_as_list(
-            [e for e in results["before"]["event_ids"]], get_prev_content=True
+            list(results["before"]["event_ids"]), get_prev_content=True
         )
 
         events_after = yield self.get_events_as_list(
-            [e for e in results["after"]["event_ids"]], get_prev_content=True
+            list(results["after"]["event_ids"]), get_prev_content=True
         )
 
         return {
@@ -698,7 +708,7 @@ class StreamWorkerStore(EventsWorkerStore, SQLBaseStore):
             dict
         """
 
-        results = self._simple_select_one_txn(
+        results = self.db.simple_select_one_txn(
             txn,
             "events",
             keyvalues={"event_id": event_id, "room_id": room_id},
@@ -777,7 +787,7 @@ class StreamWorkerStore(EventsWorkerStore, SQLBaseStore):
 
             return upper_bound, [row[1] for row in rows]
 
-        upper_bound, event_ids = yield self.runInteraction(
+        upper_bound, event_ids = yield self.db.runInteraction(
             "get_all_new_events_stream", get_all_new_events_stream_txn
         )
 
@@ -786,7 +796,7 @@ class StreamWorkerStore(EventsWorkerStore, SQLBaseStore):
         return upper_bound, events
 
     def get_federation_out_pos(self, typ):
-        return self._simple_select_one_onecol(
+        return self.db.simple_select_one_onecol(
             table="federation_stream_position",
             retcol="stream_id",
             keyvalues={"type": typ},
@@ -794,7 +804,7 @@ class StreamWorkerStore(EventsWorkerStore, SQLBaseStore):
         )
 
     def update_federation_out_pos(self, typ, stream_id):
-        return self._simple_update_one(
+        return self.db.simple_update_one(
             table="federation_stream_position",
             keyvalues={"type": typ},
             updatevalues={"stream_id": stream_id},
@@ -863,13 +873,38 @@ class StreamWorkerStore(EventsWorkerStore, SQLBaseStore):
 
         args.append(int(limit))
 
-        sql = (
-            "SELECT event_id, topological_ordering, stream_ordering"
-            " FROM events"
-            " WHERE outlier = ? AND room_id = ? AND %(bounds)s"
-            " ORDER BY topological_ordering %(order)s,"
-            " stream_ordering %(order)s LIMIT ?"
-        ) % {"bounds": bounds, "order": order}
+        select_keywords = "SELECT"
+        join_clause = ""
+        if event_filter and event_filter.labels:
+            # If we're not filtering on a label, then joining on event_labels will
+            # return as many row for a single event as the number of labels it has. To
+            # avoid this, only join if we're filtering on at least one label.
+            join_clause = """
+                LEFT JOIN event_labels
+                USING (event_id, room_id, topological_ordering)
+            """
+            if len(event_filter.labels) > 1:
+                # Using DISTINCT in this SELECT query is quite expensive, because it
+                # requires the engine to sort on the entire (not limited) result set,
+                # i.e. the entire events table. We only need to use it when we're
+                # filtering on more than one label, because that's the only scenario
+                # in which we can possibly get the same event ID multiple times in
+                # the results.
+                select_keywords += " DISTINCT"
+
+        sql = """
+            %(select_keywords)s event_id, topological_ordering, stream_ordering
+            FROM events
+            %(join_clause)s
+            WHERE outlier = ? AND room_id = ? AND %(bounds)s
+            ORDER BY topological_ordering %(order)s,
+            stream_ordering %(order)s LIMIT ?
+        """ % {
+            "select_keywords": select_keywords,
+            "join_clause": join_clause,
+            "bounds": bounds,
+            "order": order,
+        }
 
         txn.execute(sql, args)
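For reference, with more than one label the interpolated query takes roughly the shape below (the bounds and ordering are illustrative placeholders; whitespace is condensed):

    # Illustrative shape of the query built above when filtering on two labels.
    select_keywords = "SELECT DISTINCT"  # only when more than one label is given
    join_clause = "LEFT JOIN event_labels USING (event_id, room_id, topological_ordering)"
    bounds = "topological_ordering <= ? AND (label = ? OR label = ?)"  # placeholder bounds
    order = "DESC"

    sql = (
        "%(select_keywords)s event_id, topological_ordering, stream_ordering "
        "FROM events %(join_clause)s "
        "WHERE outlier = ? AND room_id = ? AND %(bounds)s "
        "ORDER BY topological_ordering %(order)s, stream_ordering %(order)s LIMIT ?"
    ) % {
        "select_keywords": select_keywords,
        "join_clause": join_clause,
        "bounds": bounds,
        "order": order,
    }
    print(sql)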
 
@@ -920,7 +955,7 @@ class StreamWorkerStore(EventsWorkerStore, SQLBaseStore):
         if to_key:
             to_key = RoomStreamToken.parse(to_key)
 
-        rows, token = yield self.runInteraction(
+        rows, token = yield self.db.runInteraction(
             "paginate_room_events",
             self._paginate_room_events_txn,
             room_id,
diff --git a/synapse/storage/tags.py b/synapse/storage/data_stores/main/tags.py
index 20dd6bd53d..4219018302 100644
--- a/synapse/storage/tags.py
+++ b/synapse/storage/data_stores/main/tags.py
@@ -22,7 +22,7 @@ from canonicaljson import json
 
 from twisted.internet import defer
 
-from synapse.storage.account_data import AccountDataWorkerStore
+from synapse.storage.data_stores.main.account_data import AccountDataWorkerStore
 from synapse.util.caches.descriptors import cached
 
 logger = logging.getLogger(__name__)
@@ -41,7 +41,7 @@ class TagsWorkerStore(AccountDataWorkerStore):
             tag strings to tag content.
         """
 
-        deferred = self._simple_select_list(
+        deferred = self.db.simple_select_list(
             "room_tags", {"user_id": user_id}, ["room_id", "tag", "content"]
         )
 
@@ -78,14 +78,12 @@ class TagsWorkerStore(AccountDataWorkerStore):
             txn.execute(sql, (last_id, current_id, limit))
             return txn.fetchall()
 
-        tag_ids = yield self.runInteraction(
+        tag_ids = yield self.db.runInteraction(
             "get_all_updated_tags", get_all_updated_tags_txn
         )
 
         def get_tag_content(txn, tag_ids):
-            sql = (
-                "SELECT tag, content" " FROM room_tags" " WHERE user_id=? AND room_id=?"
-            )
+            sql = "SELECT tag, content FROM room_tags WHERE user_id=? AND room_id=?"
             results = []
             for stream_id, user_id, room_id in tag_ids:
                 txn.execute(sql, (user_id, room_id))
@@ -100,7 +98,7 @@ class TagsWorkerStore(AccountDataWorkerStore):
         batch_size = 50
         results = []
         for i in range(0, len(tag_ids), batch_size):
-            tags = yield self.runInteraction(
+            tags = yield self.db.runInteraction(
                 "get_all_updated_tag_content",
                 get_tag_content,
                 tag_ids[i : i + batch_size],
@@ -137,7 +135,9 @@ class TagsWorkerStore(AccountDataWorkerStore):
         if not changed:
             return {}
 
-        room_ids = yield self.runInteraction("get_updated_tags", get_updated_tags_txn)
+        room_ids = yield self.db.runInteraction(
+            "get_updated_tags", get_updated_tags_txn
+        )
 
         results = {}
         if room_ids:
@@ -155,7 +155,7 @@ class TagsWorkerStore(AccountDataWorkerStore):
         Returns:
             A deferred list of string tags.
         """
-        return self._simple_select_list(
+        return self.db.simple_select_list(
             table="room_tags",
             keyvalues={"user_id": user_id, "room_id": room_id},
             retcols=("tag", "content"),
@@ -180,7 +180,7 @@ class TagsStore(TagsWorkerStore):
         content_json = json.dumps(content)
 
         def add_tag_txn(txn, next_id):
-            self._simple_upsert_txn(
+            self.db.simple_upsert_txn(
                 txn,
                 table="room_tags",
                 keyvalues={"user_id": user_id, "room_id": room_id, "tag": tag},
@@ -189,7 +189,7 @@ class TagsStore(TagsWorkerStore):
             self._update_revision_txn(txn, user_id, room_id, next_id)
 
         with self._account_data_id_gen.get_next() as next_id:
-            yield self.runInteraction("add_tag", add_tag_txn, next_id)
+            yield self.db.runInteraction("add_tag", add_tag_txn, next_id)
 
         self.get_tags_for_user.invalidate((user_id,))
 
@@ -212,7 +212,7 @@ class TagsStore(TagsWorkerStore):
             self._update_revision_txn(txn, user_id, room_id, next_id)
 
         with self._account_data_id_gen.get_next() as next_id:
-            yield self.runInteraction("remove_tag", remove_tag_txn, next_id)
+            yield self.db.runInteraction("remove_tag", remove_tag_txn, next_id)
 
         self.get_tags_for_user.invalidate((user_id,))
 
@@ -233,6 +233,9 @@ class TagsStore(TagsWorkerStore):
             self._account_data_stream_cache.entity_has_changed, user_id, next_id
         )
 
+        # Note: This is only here for backwards compat to allow admins to
+        # roll back to a previous Synapse version. Next time we update the
+        # database version we can remove this table.
         update_max_id_sql = (
             "UPDATE account_data_max_stream_id"
             " SET stream_id = ?"
diff --git a/synapse/storage/transactions.py b/synapse/storage/data_stores/main/transactions.py
index 289c117396..a9bf457939 100644
--- a/synapse/storage/transactions.py
+++ b/synapse/storage/data_stores/main/transactions.py
@@ -16,23 +16,16 @@
 import logging
 from collections import namedtuple
 
-import six
-
 from canonicaljson import encode_canonical_json
 
 from twisted.internet import defer
 
 from synapse.metrics.background_process_metrics import run_as_background_process
+from synapse.storage._base import SQLBaseStore, db_to_json
+from synapse.storage.database import Database
 from synapse.util.caches.expiringcache import ExpiringCache
 
-from ._base import SQLBaseStore, db_to_json
-
-# py2 sqlite has buffer hardcoded as only binary type, so we must use it,
-# despite being deprecated and removed in favor of memoryview
-if six.PY2:
-    db_binary_type = six.moves.builtins.buffer
-else:
-    db_binary_type = memoryview
+db_binary_type = memoryview
 
 logger = logging.getLogger(__name__)
 
@@ -53,8 +46,8 @@ class TransactionStore(SQLBaseStore):
     """A collection of queries for handling PDUs.
     """
 
-    def __init__(self, db_conn, hs):
-        super(TransactionStore, self).__init__(db_conn, hs)
+    def __init__(self, database: Database, db_conn, hs):
+        super(TransactionStore, self).__init__(database, db_conn, hs)
 
         self._clock.looping_call(self._start_cleanup_transactions, 30 * 60 * 1000)
 
@@ -78,7 +71,7 @@ class TransactionStore(SQLBaseStore):
             this transaction or a 2-tuple of (int, dict)
         """
 
-        return self.runInteraction(
+        return self.db.runInteraction(
             "get_received_txn_response",
             self._get_received_txn_response,
             transaction_id,
@@ -86,7 +79,7 @@ class TransactionStore(SQLBaseStore):
         )
 
     def _get_received_txn_response(self, txn, transaction_id, origin):
-        result = self._simple_select_one_txn(
+        result = self.db.simple_select_one_txn(
             txn,
             table="received_transactions",
             keyvalues={"transaction_id": transaction_id, "origin": origin},
@@ -120,7 +113,7 @@ class TransactionStore(SQLBaseStore):
             response_json (str)
         """
 
-        return self._simple_insert(
+        return self.db.simple_insert(
             table="received_transactions",
             values={
                 "transaction_id": transaction_id,
@@ -149,7 +142,7 @@ class TransactionStore(SQLBaseStore):
         if result is not SENTINEL:
             return result
 
-        result = yield self.runInteraction(
+        result = yield self.db.runInteraction(
             "get_destination_retry_timings",
             self._get_destination_retry_timings,
             destination,
@@ -161,7 +154,7 @@ class TransactionStore(SQLBaseStore):
         return result
 
     def _get_destination_retry_timings(self, txn, destination):
-        result = self._simple_select_one_txn(
+        result = self.db.simple_select_one_txn(
             txn,
             table="destinations",
             keyvalues={"destination": destination},
@@ -188,7 +181,7 @@ class TransactionStore(SQLBaseStore):
         """
 
         self._destination_retry_cache.pop(destination, None)
-        return self.runInteraction(
+        return self.db.runInteraction(
             "set_destination_retry_timings",
             self._set_destination_retry_timings,
             destination,
@@ -228,7 +221,7 @@ class TransactionStore(SQLBaseStore):
         # We need to be careful here as the data may have changed from under us
         # due to a worker setting the timings.
 
-        prev_row = self._simple_select_one_txn(
+        prev_row = self.db.simple_select_one_txn(
             txn,
             table="destinations",
             keyvalues={"destination": destination},
@@ -237,7 +230,7 @@ class TransactionStore(SQLBaseStore):
         )
 
         if not prev_row:
-            self._simple_insert_txn(
+            self.db.simple_insert_txn(
                 txn,
                 table="destinations",
                 values={
@@ -248,7 +241,7 @@ class TransactionStore(SQLBaseStore):
                 },
             )
         elif retry_interval == 0 or prev_row["retry_interval"] < retry_interval:
-            self._simple_update_one_txn(
+            self.db.simple_update_one_txn(
                 txn,
                 "destinations",
                 keyvalues={"destination": destination},
@@ -271,4 +264,6 @@ class TransactionStore(SQLBaseStore):
         def _cleanup_transactions_txn(txn):
             txn.execute("DELETE FROM received_transactions WHERE ts < ?", (month_ago,))
 
-        return self.runInteraction("_cleanup_transactions", _cleanup_transactions_txn)
+        return self.db.runInteraction(
+            "_cleanup_transactions", _cleanup_transactions_txn
+        )
diff --git a/synapse/storage/data_stores/main/ui_auth.py b/synapse/storage/data_stores/main/ui_auth.py
new file mode 100644
index 0000000000..1d8ee22fb1
--- /dev/null
+++ b/synapse/storage/data_stores/main/ui_auth.py
@@ -0,0 +1,300 @@
+# -*- coding: utf-8 -*-
+# Copyright 2020 Matrix.org Foundation C.I.C.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+import json
+from typing import Any, Dict, Optional, Union
+
+import attr
+
+import synapse.util.stringutils as stringutils
+from synapse.api.errors import StoreError
+from synapse.storage._base import SQLBaseStore
+from synapse.types import JsonDict
+
+
+@attr.s
+class UIAuthSessionData:
+    session_id = attr.ib(type=str)
+    # The dictionary from the client root level, not the 'auth' key.
+    clientdict = attr.ib(type=JsonDict)
+    # The URI and method the session was initiated with. These are checked at
+    # each stage of the authentication to ensure that the requested operation
+    # has not changed.
+    uri = attr.ib(type=str)
+    method = attr.ib(type=str)
+    # A string description of the operation that the current authentication is
+    # authorising.
+    description = attr.ib(type=str)
+
+
+class UIAuthWorkerStore(SQLBaseStore):
+    """
+    Manage user interactive authentication sessions.
+    """
+
+    async def create_ui_auth_session(
+        self, clientdict: JsonDict, uri: str, method: str, description: str,
+    ) -> UIAuthSessionData:
+        """
+        Creates a new user interactive authentication session.
+
+        The session can be used to track the stages necessary to authenticate a
+        user across multiple HTTP requests.
+
+        Args:
+            clientdict:
+                The dictionary from the client root level, not the 'auth' key.
+            uri:
+                The URI this session was initiated with; this is checked at each
+                stage of the authentication to ensure that the requested
+                operation has not changed.
+            method:
+                The method this session was initiated with; this is checked at each
+                stage of the authentication to ensure that the requested
+                operation has not changed.
+            description:
+                A string description of the operation that the current
+                authentication is authorising.
+        Returns:
+            The newly created session.
+        Raises:
+            StoreError if a unique session ID cannot be generated.
+        """
+        # The clientdict gets stored as JSON.
+        clientdict_json = json.dumps(clientdict)
+
+        # autogen a session ID and try to create it. We may clash, so just
+        # try a few times till one goes through, giving up eventually.
+        attempts = 0
+        while attempts < 5:
+            session_id = stringutils.random_string(24)
+
+            try:
+                await self.db.simple_insert(
+                    table="ui_auth_sessions",
+                    values={
+                        "session_id": session_id,
+                        "clientdict": clientdict_json,
+                        "uri": uri,
+                        "method": method,
+                        "description": description,
+                        "serverdict": "{}",
+                        "creation_time": self.hs.get_clock().time_msec(),
+                    },
+                    desc="create_ui_auth_session",
+                )
+                return UIAuthSessionData(
+                    session_id, clientdict, uri, method, description
+                )
+            except self.db.engine.module.IntegrityError:
+                attempts += 1
+        raise StoreError(500, "Couldn't generate a session ID.")
+
+    async def get_ui_auth_session(self, session_id: str) -> UIAuthSessionData:
+        """Retrieve a UI auth session.
+
+        Args:
+            session_id: The ID of the session.
+        Returns:
+            The UIAuthSessionData for the session.
+        Raises:
+            StoreError if the session is not found.
+        """
+        result = await self.db.simple_select_one(
+            table="ui_auth_sessions",
+            keyvalues={"session_id": session_id},
+            retcols=("clientdict", "uri", "method", "description"),
+            desc="get_ui_auth_session",
+        )
+
+        result["clientdict"] = json.loads(result["clientdict"])
+
+        return UIAuthSessionData(session_id, **result)
+
+    async def mark_ui_auth_stage_complete(
+        self, session_id: str, stage_type: str, result: Union[str, bool, JsonDict],
+    ):
+        """
+        Mark a session stage as completed.
+
+        Args:
+            session_id: The ID of the corresponding session.
+            stage_type: The completed stage type.
+            result: The result of the stage verification.
+        Raises:
+            StoreError if the session cannot be found.
+        """
+        # Add (or update) the results of the current stage to the database.
+        #
+        # Note that we need to allow for the same stage to complete multiple
+        # times here so that registration is idempotent.
+        try:
+            await self.db.simple_upsert(
+                table="ui_auth_sessions_credentials",
+                keyvalues={"session_id": session_id, "stage_type": stage_type},
+                values={"result": json.dumps(result)},
+                desc="mark_ui_auth_stage_complete",
+            )
+        except self.db.engine.module.IntegrityError:
+            raise StoreError(400, "Unknown session ID: %s" % (session_id,))
+
+    async def get_completed_ui_auth_stages(
+        self, session_id: str
+    ) -> Dict[str, Union[str, bool, JsonDict]]:
+        """
+        Retrieve the completed stages of a UI authentication session.
+
+        Args:
+            session_id: The ID of the session.
+        Returns:
+            The completed stages mapped to the result of the verification of
+            that auth-type.
+        """
+        results = {}
+        for row in await self.db.simple_select_list(
+            table="ui_auth_sessions_credentials",
+            keyvalues={"session_id": session_id},
+            retcols=("stage_type", "result"),
+            desc="get_completed_ui_auth_stages",
+        ):
+            results[row["stage_type"]] = json.loads(row["result"])
+
+        return results
+
+    async def set_ui_auth_clientdict(
+        self, session_id: str, clientdict: JsonDict
+    ) -> None:
+        """
+        Store an updated clientdict for a given session ID.
+
+        Args:
+            session_id: The ID of this session as returned from check_auth
+            clientdict:
+                The dictionary from the client root level, not the 'auth' key.
+        """
+        # The clientdict gets stored as JSON.
+        clientdict_json = json.dumps(clientdict)
+
+        await self.db.simple_update_one(
+            table="ui_auth_sessions",
+            keyvalues={"session_id": session_id},
+            updatevalues={"clientdict": clientdict_json},
+            desc="set_ui_auth_client_dict",
+        )
+
+    async def set_ui_auth_session_data(self, session_id: str, key: str, value: Any):
+        """
+        Store a key-value pair into the sessions data associated with this
+        request. This data is stored server-side and cannot be modified by
+        the client.
+
+        Args:
+            session_id: The ID of this session as returned from check_auth
+            key: The key to store the data under
+            value: The data to store
+        Raises:
+            StoreError if the session cannot be found.
+        """
+        await self.db.runInteraction(
+            "set_ui_auth_session_data",
+            self._set_ui_auth_session_data_txn,
+            session_id,
+            key,
+            value,
+        )
+
+    def _set_ui_auth_session_data_txn(self, txn, session_id: str, key: str, value: Any):
+        # Get the current value.
+        result = self.db.simple_select_one_txn(
+            txn,
+            table="ui_auth_sessions",
+            keyvalues={"session_id": session_id},
+            retcols=("serverdict",),
+        )
+
+        # Update it and add it back to the database.
+        serverdict = json.loads(result["serverdict"])
+        serverdict[key] = value
+
+        self.db.simple_update_one_txn(
+            txn,
+            table="ui_auth_sessions",
+            keyvalues={"session_id": session_id},
+            updatevalues={"serverdict": json.dumps(serverdict)},
+        )
+
+    async def get_ui_auth_session_data(
+        self, session_id: str, key: str, default: Optional[Any] = None
+    ) -> Any:
+        """
+        Retrieve data stored with set_ui_auth_session_data
+
+        Args:
+            session_id: The ID of this session as returned from check_auth
+            key: The key of the data to retrieve
+            default: Value to return if the key has not been set
+        Raises:
+            StoreError if the session cannot be found.
+        """
+        result = await self.db.simple_select_one(
+            table="ui_auth_sessions",
+            keyvalues={"session_id": session_id},
+            retcols=("serverdict",),
+            desc="get_ui_auth_session_data",
+        )
+
+        serverdict = json.loads(result["serverdict"])
+
+        return serverdict.get(key, default)
+
+
+class UIAuthStore(UIAuthWorkerStore):
+    def delete_old_ui_auth_sessions(self, expiration_time: int):
+        """
+        Remove sessions which were last used earlier than the expiration time.
+
+        Args:
+            expiration_time: The latest time that is still considered valid.
+                This is an epoch time in milliseconds.
+
+        """
+        return self.db.runInteraction(
+            "delete_old_ui_auth_sessions",
+            self._delete_old_ui_auth_sessions_txn,
+            expiration_time,
+        )
+
+    def _delete_old_ui_auth_sessions_txn(self, txn, expiration_time: int):
+        # Get the expired sessions.
+        sql = "SELECT session_id FROM ui_auth_sessions WHERE creation_time <= ?"
+        txn.execute(sql, [expiration_time])
+        session_ids = [r[0] for r in txn.fetchall()]
+
+        # Delete the corresponding completed credentials.
+        self.db.simple_delete_many_txn(
+            txn,
+            table="ui_auth_sessions_credentials",
+            column="session_id",
+            iterable=session_ids,
+            keyvalues={},
+        )
+
+        # Finally, delete the sessions.
+        self.db.simple_delete_many_txn(
+            txn,
+            table="ui_auth_sessions",
+            column="session_id",
+            iterable=session_ids,
+            keyvalues={},
+        )
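Taken together, the new UIAuthWorkerStore methods are intended to be driven along these lines (a sketch only: `store` is assumed to expose the methods above, the stage type and result values are examples, and error handling is omitted):

    # Sketch of the intended flow through the UI auth session store above.
    async def run_ui_auth(store, clientdict, uri, method):
        # 1. Create a session the first time the client hits the endpoint.
        session = await store.create_ui_auth_session(
            clientdict, uri, method, description="add a 3pid to your account"
        )

        # 2. As the client completes each stage, record the result.
        await store.mark_ui_auth_stage_complete(
            session.session_id, "m.login.password", result="@alice:example.org"
        )

        # 3. Later requests re-load the session and check what has been done.
        session = await store.get_ui_auth_session(session.session_id)
        completed = await store.get_completed_ui_auth_stages(session.session_id)

        # 4. Server-side scratch data survives across requests too.
        await store.set_ui_auth_session_data(
            session.session_id, "threepid", {"medium": "email"}
        )
        threepid = await store.get_ui_auth_session_data(session.session_id, "threepid")

        return completed, threepid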
diff --git a/synapse/storage/user_directory.py b/synapse/storage/data_stores/main/user_directory.py
index b5188d9bee..6b8130bf0f 100644
--- a/synapse/storage/user_directory.py
+++ b/synapse/storage/data_stores/main/user_directory.py
@@ -19,10 +19,10 @@ import re
 from twisted.internet import defer
 
 from synapse.api.constants import EventTypes, JoinRules
-from synapse.storage.background_updates import BackgroundUpdateStore
+from synapse.storage.data_stores.main.state import StateFilter
+from synapse.storage.data_stores.main.state_deltas import StateDeltasStore
+from synapse.storage.database import Database
 from synapse.storage.engines import PostgresEngine, Sqlite3Engine
-from synapse.storage.state import StateFilter
-from synapse.storage.state_deltas import StateDeltasStore
 from synapse.types import get_domain_from_id, get_localpart_from_id
 from synapse.util.caches.descriptors import cached
 
@@ -32,30 +32,30 @@ logger = logging.getLogger(__name__)
 TEMP_TABLE = "_temp_populate_user_directory"
 
 
-class UserDirectoryStore(StateDeltasStore, BackgroundUpdateStore):
+class UserDirectoryBackgroundUpdateStore(StateDeltasStore):
 
     # How many records do we calculate before sending it to
     # add_users_who_share_private_rooms?
     SHARE_PRIVATE_WORKING_SET = 500
 
-    def __init__(self, db_conn, hs):
-        super(UserDirectoryStore, self).__init__(db_conn, hs)
+    def __init__(self, database: Database, db_conn, hs):
+        super(UserDirectoryBackgroundUpdateStore, self).__init__(database, db_conn, hs)
 
         self.server_name = hs.hostname
 
-        self.register_background_update_handler(
+        self.db.updates.register_background_update_handler(
             "populate_user_directory_createtables",
             self._populate_user_directory_createtables,
         )
-        self.register_background_update_handler(
+        self.db.updates.register_background_update_handler(
             "populate_user_directory_process_rooms",
             self._populate_user_directory_process_rooms,
         )
-        self.register_background_update_handler(
+        self.db.updates.register_background_update_handler(
             "populate_user_directory_process_users",
             self._populate_user_directory_process_users,
         )
-        self.register_background_update_handler(
+        self.db.updates.register_background_update_handler(
             "populate_user_directory_cleanup", self._populate_user_directory_cleanup
         )
 
@@ -85,7 +85,7 @@ class UserDirectoryStore(StateDeltasStore, BackgroundUpdateStore):
             """
             txn.execute(sql)
             rooms = [{"room_id": x[0], "events": x[1]} for x in txn.fetchall()]
-            self._simple_insert_many_txn(txn, TEMP_TABLE + "_rooms", rooms)
+            self.db.simple_insert_many_txn(txn, TEMP_TABLE + "_rooms", rooms)
             del rooms
 
             # If search all users is on, get all the users we want to add.
@@ -100,15 +100,17 @@ class UserDirectoryStore(StateDeltasStore, BackgroundUpdateStore):
                 txn.execute("SELECT name FROM users")
                 users = [{"user_id": x[0]} for x in txn.fetchall()]
 
-                self._simple_insert_many_txn(txn, TEMP_TABLE + "_users", users)
+                self.db.simple_insert_many_txn(txn, TEMP_TABLE + "_users", users)
 
         new_pos = yield self.get_max_stream_id_in_current_state_deltas()
-        yield self.runInteraction(
+        yield self.db.runInteraction(
             "populate_user_directory_temp_build", _make_staging_area
         )
-        yield self._simple_insert(TEMP_TABLE + "_position", {"position": new_pos})
+        yield self.db.simple_insert(TEMP_TABLE + "_position", {"position": new_pos})
 
-        yield self._end_background_update("populate_user_directory_createtables")
+        yield self.db.updates._end_background_update(
+            "populate_user_directory_createtables"
+        )
         return 1
 
     @defer.inlineCallbacks
@@ -116,7 +118,7 @@ class UserDirectoryStore(StateDeltasStore, BackgroundUpdateStore):
         """
         Update the user directory stream position, then clean up the old tables.
         """
-        position = yield self._simple_select_one_onecol(
+        position = yield self.db.simple_select_one_onecol(
             TEMP_TABLE + "_position", None, "position"
         )
         yield self.update_user_directory_stream_pos(position)
@@ -126,11 +128,11 @@ class UserDirectoryStore(StateDeltasStore, BackgroundUpdateStore):
             txn.execute("DROP TABLE IF EXISTS " + TEMP_TABLE + "_users")
             txn.execute("DROP TABLE IF EXISTS " + TEMP_TABLE + "_position")
 
-        yield self.runInteraction(
+        yield self.db.runInteraction(
             "populate_user_directory_cleanup", _delete_staging_area
         )
 
-        yield self._end_background_update("populate_user_directory_cleanup")
+        yield self.db.updates._end_background_update("populate_user_directory_cleanup")
         return 1
 
     @defer.inlineCallbacks
@@ -170,16 +172,18 @@ class UserDirectoryStore(StateDeltasStore, BackgroundUpdateStore):
 
             return rooms_to_work_on
 
-        rooms_to_work_on = yield self.runInteraction(
+        rooms_to_work_on = yield self.db.runInteraction(
             "populate_user_directory_temp_read", _get_next_batch
         )
 
         # No more rooms -- complete the transaction.
         if not rooms_to_work_on:
-            yield self._end_background_update("populate_user_directory_process_rooms")
+            yield self.db.updates._end_background_update(
+                "populate_user_directory_process_rooms"
+            )
             return 1
 
-        logger.info(
+        logger.debug(
             "Processing the next %d rooms of %d remaining"
             % (len(rooms_to_work_on), progress["remaining"])
         )
@@ -243,12 +247,12 @@ class UserDirectoryStore(StateDeltasStore, BackgroundUpdateStore):
                         to_insert.clear()
 
             # We've finished a room. Delete it from the table.
-            yield self._simple_delete_one(TEMP_TABLE + "_rooms", {"room_id": room_id})
+            yield self.db.simple_delete_one(TEMP_TABLE + "_rooms", {"room_id": room_id})
             # Update the remaining counter.
             progress["remaining"] -= 1
-            yield self.runInteraction(
+            yield self.db.runInteraction(
                 "populate_user_directory",
-                self._background_update_progress_txn,
+                self.db.updates._background_update_progress_txn,
                 "populate_user_directory_process_rooms",
                 progress,
             )
@@ -267,7 +271,9 @@ class UserDirectoryStore(StateDeltasStore, BackgroundUpdateStore):
         If search_all_users is enabled, add all of the users to the user directory.
         """
         if not self.hs.config.user_directory_search_all_users:
-            yield self._end_background_update("populate_user_directory_process_users")
+            yield self.db.updates._end_background_update(
+                "populate_user_directory_process_users"
+            )
             return 1
 
         def _get_next_batch(txn):
@@ -291,16 +297,18 @@ class UserDirectoryStore(StateDeltasStore, BackgroundUpdateStore):
 
             return users_to_work_on
 
-        users_to_work_on = yield self.runInteraction(
+        users_to_work_on = yield self.db.runInteraction(
             "populate_user_directory_temp_read", _get_next_batch
         )
 
         # No more users -- complete the transaction.
         if not users_to_work_on:
-            yield self._end_background_update("populate_user_directory_process_users")
+            yield self.db.updates._end_background_update(
+                "populate_user_directory_process_users"
+            )
             return 1
 
-        logger.info(
+        logger.debug(
             "Processing the next %d users of %d remaining"
             % (len(users_to_work_on), progress["remaining"])
         )
@@ -312,12 +320,12 @@ class UserDirectoryStore(StateDeltasStore, BackgroundUpdateStore):
             )
 
             # We've finished processing a user. Delete it from the table.
-            yield self._simple_delete_one(TEMP_TABLE + "_users", {"user_id": user_id})
+            yield self.db.simple_delete_one(TEMP_TABLE + "_users", {"user_id": user_id})
             # Update the remaining counter.
             progress["remaining"] -= 1
-            yield self.runInteraction(
+            yield self.db.runInteraction(
                 "populate_user_directory",
-                self._background_update_progress_txn,
+                self.db.updates._background_update_progress_txn,
                 "populate_user_directory_process_users",
                 progress,
             )
@@ -361,7 +369,7 @@ class UserDirectoryStore(StateDeltasStore, BackgroundUpdateStore):
         """
 
         def _update_profile_in_user_dir_txn(txn):
-            new_entry = self._simple_upsert_txn(
+            new_entry = self.db.simple_upsert_txn(
                 txn,
                 table="user_directory",
                 keyvalues={"user_id": user_id},
@@ -435,7 +443,7 @@ class UserDirectoryStore(StateDeltasStore, BackgroundUpdateStore):
                         )
             elif isinstance(self.database_engine, Sqlite3Engine):
                 value = "%s %s" % (user_id, display_name) if display_name else user_id
-                self._simple_upsert_txn(
+                self.db.simple_upsert_txn(
                     txn,
                     table="user_directory_search",
                     keyvalues={"user_id": user_id},
@@ -448,59 +456,10 @@ class UserDirectoryStore(StateDeltasStore, BackgroundUpdateStore):
 
             txn.call_after(self.get_user_in_directory.invalidate, (user_id,))
 
-        return self.runInteraction(
+        return self.db.runInteraction(
             "update_profile_in_user_dir", _update_profile_in_user_dir_txn
         )
 
-    def remove_from_user_dir(self, user_id):
-        def _remove_from_user_dir_txn(txn):
-            self._simple_delete_txn(
-                txn, table="user_directory", keyvalues={"user_id": user_id}
-            )
-            self._simple_delete_txn(
-                txn, table="user_directory_search", keyvalues={"user_id": user_id}
-            )
-            self._simple_delete_txn(
-                txn, table="users_in_public_rooms", keyvalues={"user_id": user_id}
-            )
-            self._simple_delete_txn(
-                txn,
-                table="users_who_share_private_rooms",
-                keyvalues={"user_id": user_id},
-            )
-            self._simple_delete_txn(
-                txn,
-                table="users_who_share_private_rooms",
-                keyvalues={"other_user_id": user_id},
-            )
-            txn.call_after(self.get_user_in_directory.invalidate, (user_id,))
-
-        return self.runInteraction("remove_from_user_dir", _remove_from_user_dir_txn)
-
-    @defer.inlineCallbacks
-    def get_users_in_dir_due_to_room(self, room_id):
-        """Get all user_ids that are in the room directory because they're
-        in the given room_id
-        """
-        user_ids_share_pub = yield self._simple_select_onecol(
-            table="users_in_public_rooms",
-            keyvalues={"room_id": room_id},
-            retcol="user_id",
-            desc="get_users_in_dir_due_to_room",
-        )
-
-        user_ids_share_priv = yield self._simple_select_onecol(
-            table="users_who_share_private_rooms",
-            keyvalues={"room_id": room_id},
-            retcol="other_user_id",
-            desc="get_users_in_dir_due_to_room",
-        )
-
-        user_ids = set(user_ids_share_pub)
-        user_ids.update(user_ids_share_priv)
-
-        return user_ids
-
     def add_users_who_share_private_room(self, room_id, user_id_tuples):
         """Insert entries into the users_who_share_private_rooms table. The first
         user should be a local user.
@@ -511,7 +470,7 @@ class UserDirectoryStore(StateDeltasStore, BackgroundUpdateStore):
         """
 
         def _add_users_who_share_room_txn(txn):
-            self._simple_upsert_many_txn(
+            self.db.simple_upsert_many_txn(
                 txn,
                 table="users_who_share_private_rooms",
                 key_names=["user_id", "other_user_id", "room_id"],
@@ -523,7 +482,7 @@ class UserDirectoryStore(StateDeltasStore, BackgroundUpdateStore):
                 value_values=None,
             )
 
-        return self.runInteraction(
+        return self.db.runInteraction(
             "add_users_who_share_room", _add_users_who_share_room_txn
         )
 
@@ -538,7 +497,7 @@ class UserDirectoryStore(StateDeltasStore, BackgroundUpdateStore):
 
         def _add_users_in_public_rooms_txn(txn):
 
-            self._simple_upsert_many_txn(
+            self.db.simple_upsert_many_txn(
                 txn,
                 table="users_in_public_rooms",
                 key_names=["user_id", "room_id"],
@@ -547,10 +506,102 @@ class UserDirectoryStore(StateDeltasStore, BackgroundUpdateStore):
                 value_values=None,
             )
 
-        return self.runInteraction(
+        return self.db.runInteraction(
             "add_users_in_public_rooms", _add_users_in_public_rooms_txn
         )
 
+    def delete_all_from_user_dir(self):
+        """Delete the entire user directory
+        """
+
+        def _delete_all_from_user_dir_txn(txn):
+            txn.execute("DELETE FROM user_directory")
+            txn.execute("DELETE FROM user_directory_search")
+            txn.execute("DELETE FROM users_in_public_rooms")
+            txn.execute("DELETE FROM users_who_share_private_rooms")
+            txn.call_after(self.get_user_in_directory.invalidate_all)
+
+        return self.db.runInteraction(
+            "delete_all_from_user_dir", _delete_all_from_user_dir_txn
+        )
+
+    @cached()
+    def get_user_in_directory(self, user_id):
+        return self.db.simple_select_one(
+            table="user_directory",
+            keyvalues={"user_id": user_id},
+            retcols=("display_name", "avatar_url"),
+            allow_none=True,
+            desc="get_user_in_directory",
+        )
+
+    def update_user_directory_stream_pos(self, stream_id):
+        return self.db.simple_update_one(
+            table="user_directory_stream_pos",
+            keyvalues={},
+            updatevalues={"stream_id": stream_id},
+            desc="update_user_directory_stream_pos",
+        )
+
+
+class UserDirectoryStore(UserDirectoryBackgroundUpdateStore):
+
+    # How many records do we calculate before sending them to
+    # add_users_who_share_private_rooms?
+    SHARE_PRIVATE_WORKING_SET = 500
+
+    def __init__(self, database: Database, db_conn, hs):
+        super(UserDirectoryStore, self).__init__(database, db_conn, hs)
+
+    def remove_from_user_dir(self, user_id):
+        def _remove_from_user_dir_txn(txn):
+            self.db.simple_delete_txn(
+                txn, table="user_directory", keyvalues={"user_id": user_id}
+            )
+            self.db.simple_delete_txn(
+                txn, table="user_directory_search", keyvalues={"user_id": user_id}
+            )
+            self.db.simple_delete_txn(
+                txn, table="users_in_public_rooms", keyvalues={"user_id": user_id}
+            )
+            self.db.simple_delete_txn(
+                txn,
+                table="users_who_share_private_rooms",
+                keyvalues={"user_id": user_id},
+            )
+            self.db.simple_delete_txn(
+                txn,
+                table="users_who_share_private_rooms",
+                keyvalues={"other_user_id": user_id},
+            )
+            txn.call_after(self.get_user_in_directory.invalidate, (user_id,))
+
+        return self.db.runInteraction("remove_from_user_dir", _remove_from_user_dir_txn)
+
+    @defer.inlineCallbacks
+    def get_users_in_dir_due_to_room(self, room_id):
+        """Get all user_ids that are in the room directory because they're
+        in the given room_id
+        """
+        user_ids_share_pub = yield self.db.simple_select_onecol(
+            table="users_in_public_rooms",
+            keyvalues={"room_id": room_id},
+            retcol="user_id",
+            desc="get_users_in_dir_due_to_room",
+        )
+
+        user_ids_share_priv = yield self.db.simple_select_onecol(
+            table="users_who_share_private_rooms",
+            keyvalues={"room_id": room_id},
+            retcol="other_user_id",
+            desc="get_users_in_dir_due_to_room",
+        )
+
+        user_ids = set(user_ids_share_pub)
+        user_ids.update(user_ids_share_priv)
+
+        return user_ids
+
     def remove_user_who_share_room(self, user_id, room_id):
         """
         Deletes entries in the users_who_share_*_rooms table. The first
@@ -562,23 +613,23 @@ class UserDirectoryStore(StateDeltasStore, BackgroundUpdateStore):
         """
 
         def _remove_user_who_share_room_txn(txn):
-            self._simple_delete_txn(
+            self.db.simple_delete_txn(
                 txn,
                 table="users_who_share_private_rooms",
                 keyvalues={"user_id": user_id, "room_id": room_id},
             )
-            self._simple_delete_txn(
+            self.db.simple_delete_txn(
                 txn,
                 table="users_who_share_private_rooms",
                 keyvalues={"other_user_id": user_id, "room_id": room_id},
             )
-            self._simple_delete_txn(
+            self.db.simple_delete_txn(
                 txn,
                 table="users_in_public_rooms",
                 keyvalues={"user_id": user_id, "room_id": room_id},
             )
 
-        return self.runInteraction(
+        return self.db.runInteraction(
             "remove_user_who_share_room", _remove_user_who_share_room_txn
         )
 
@@ -593,14 +644,14 @@ class UserDirectoryStore(StateDeltasStore, BackgroundUpdateStore):
         Returns:
             list: user_id
         """
-        rows = yield self._simple_select_onecol(
+        rows = yield self.db.simple_select_onecol(
             table="users_who_share_private_rooms",
             keyvalues={"user_id": user_id},
             retcol="room_id",
             desc="get_rooms_user_is_in",
         )
 
-        pub_rows = yield self._simple_select_onecol(
+        pub_rows = yield self.db.simple_select_onecol(
             table="users_in_public_rooms",
             keyvalues={"user_id": user_id},
             retcol="room_id",
@@ -631,53 +682,20 @@ class UserDirectoryStore(StateDeltasStore, BackgroundUpdateStore):
             ) f2 USING (room_id)
         """
 
-        rows = yield self._execute(
+        rows = yield self.db.execute(
             "get_rooms_in_common_for_users", None, sql, user_id, other_user_id
         )
 
         return [room_id for room_id, in rows]
 
-    def delete_all_from_user_dir(self):
-        """Delete the entire user directory
-        """
-
-        def _delete_all_from_user_dir_txn(txn):
-            txn.execute("DELETE FROM user_directory")
-            txn.execute("DELETE FROM user_directory_search")
-            txn.execute("DELETE FROM users_in_public_rooms")
-            txn.execute("DELETE FROM users_who_share_private_rooms")
-            txn.call_after(self.get_user_in_directory.invalidate_all)
-
-        return self.runInteraction(
-            "delete_all_from_user_dir", _delete_all_from_user_dir_txn
-        )
-
-    @cached()
-    def get_user_in_directory(self, user_id):
-        return self._simple_select_one(
-            table="user_directory",
-            keyvalues={"user_id": user_id},
-            retcols=("display_name", "avatar_url"),
-            allow_none=True,
-            desc="get_user_in_directory",
-        )
-
     def get_user_directory_stream_pos(self):
-        return self._simple_select_one_onecol(
+        return self.db.simple_select_one_onecol(
             table="user_directory_stream_pos",
             keyvalues={},
             retcol="stream_id",
             desc="get_user_directory_stream_pos",
         )
 
-    def update_user_directory_stream_pos(self, stream_id):
-        return self._simple_update_one(
-            table="user_directory_stream_pos",
-            keyvalues={},
-            updatevalues={"stream_id": stream_id},
-            desc="update_user_directory_stream_pos",
-        )
-
     @defer.inlineCallbacks
     def search_user_dir(self, user_id, search_term, limit):
         """Searches for users in directory
@@ -776,8 +794,8 @@ class UserDirectoryStore(StateDeltasStore, BackgroundUpdateStore):
             # This should be unreachable.
             raise Exception("Unrecognized database engine")
 
-        results = yield self._execute(
-            "search_user_dir", self.cursor_to_dict, sql, *args
+        results = yield self.db.execute(
+            "search_user_dir", self.db.cursor_to_dict, sql, *args
         )
 
         limited = len(results) > limit
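
The tail of search_user_dir above detects truncation by over-fetching: limited is only true when the query handed back more rows than the caller asked for, which presumes the SQL (not visible in this hunk) requests one row more than `limit`. A minimal sketch of that pattern, with purely illustrative names:

    def paged(rows, limit):
        """Return at most `limit` rows plus a flag saying whether more exist."""
        window = rows[: limit + 1]  # stand-in for a SQL query issued with limit + 1
        limited = len(window) > limit
        return window[:limit], limited

    print(paged(list(range(5)), limit=3))  # ([0, 1, 2], True)
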
diff --git a/synapse/storage/user_erasure_store.py b/synapse/storage/data_stores/main/user_erasure_store.py
index 05cabc2282..ec6b8a4ffd 100644
--- a/synapse/storage/user_erasure_store.py
+++ b/synapse/storage/data_stores/main/user_erasure_store.py
@@ -31,7 +31,7 @@ class UserErasureWorkerStore(SQLBaseStore):
         Returns:
             Deferred[bool]: True if the user has requested erasure
         """
-        return self._simple_select_onecol(
+        return self.db.simple_select_onecol(
             table="erased_users",
             keyvalues={"user_id": user_id},
             retcol="1",
@@ -56,16 +56,16 @@ class UserErasureWorkerStore(SQLBaseStore):
         # iterate it multiple times, and (b) avoiding duplicates.
         user_ids = tuple(set(user_ids))
 
-        def _get_erased_users(txn):
-            txn.execute(
-                "SELECT user_id FROM erased_users WHERE user_id IN (%s)"
-                % (",".join("?" * len(user_ids))),
-                user_ids,
-            )
-            return set(r[0] for r in txn)
-
-        erased_users = yield self.runInteraction("are_users_erased", _get_erased_users)
-        res = dict((u, u in erased_users) for u in user_ids)
+        rows = yield self.db.simple_select_many_batch(
+            table="erased_users",
+            column="user_id",
+            iterable=user_ids,
+            retcols=("user_id",),
+            desc="are_users_erased",
+        )
+        erased_users = {row["user_id"] for row in rows}
+
+        res = {u: u in erased_users for u in user_ids}
         return res
 
 
@@ -88,4 +88,4 @@ class UserErasureStore(UserErasureWorkerStore):
 
             self._invalidate_cache_and_stream(txn, self.is_user_erased, (user_id,))
 
-        return self.runInteraction("mark_user_erased", f)
+        return self.db.runInteraction("mark_user_erased", f)
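
The are_users_erased hunk above swaps a hand-rolled IN (...) query for the simple_select_many_batch helper shown in the diff. A self-contained sketch of the general technique, assuming only that such a helper splits the iterable into fixed-size chunks and issues one parameterised IN query per chunk (the chunk size and internals below are illustrative, not the helper's actual implementation):

    import sqlite3

    def select_in_batches(conn, table, column, values, retcol, batch_size=100):
        """Fetch `retcol` for rows whose `column` is in `values`, in fixed-size chunks."""
        found = []
        values = list(values)
        for i in range(0, len(values), batch_size):
            chunk = values[i : i + batch_size]
            placeholders = ",".join("?" * len(chunk))
            # Table/column names are trusted here; only the values are parameterised.
            sql = "SELECT %s FROM %s WHERE %s IN (%s)" % (retcol, table, column, placeholders)
            found.extend(row[0] for row in conn.execute(sql, chunk))
        return found

    conn = sqlite3.connect(":memory:")
    conn.execute("CREATE TABLE erased_users (user_id TEXT PRIMARY KEY)")
    conn.execute("INSERT INTO erased_users VALUES ('@a:test')")
    erased = set(select_in_batches(conn, "erased_users", "user_id", ["@a:test", "@b:test"], "user_id"))
    print({u: u in erased for u in ["@a:test", "@b:test"]})  # {'@a:test': True, '@b:test': False}
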
diff --git a/synapse/storage/data_stores/state/__init__.py b/synapse/storage/data_stores/state/__init__.py
new file mode 100644
index 0000000000..86e09f6229
--- /dev/null
+++ b/synapse/storage/data_stores/state/__init__.py
@@ -0,0 +1,16 @@
+# -*- coding: utf-8 -*-
+# Copyright 2019 The Matrix.org Foundation C.I.C.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from synapse.storage.data_stores.state.store import StateGroupDataStore  # noqa: F401
diff --git a/synapse/storage/data_stores/state/bg_updates.py b/synapse/storage/data_stores/state/bg_updates.py
new file mode 100644
index 0000000000..ff000bc9ec
--- /dev/null
+++ b/synapse/storage/data_stores/state/bg_updates.py
@@ -0,0 +1,374 @@
+# -*- coding: utf-8 -*-
+# Copyright 2014-2016 OpenMarket Ltd
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import logging
+
+from six import iteritems
+
+from twisted.internet import defer
+
+from synapse.storage._base import SQLBaseStore
+from synapse.storage.database import Database
+from synapse.storage.engines import PostgresEngine
+from synapse.storage.state import StateFilter
+
+logger = logging.getLogger(__name__)
+
+
+MAX_STATE_DELTA_HOPS = 100
+
+
+class StateGroupBackgroundUpdateStore(SQLBaseStore):
+    """Defines functions related to state groups needed to run the state backgroud
+    updates.
+    """
+
+    def _count_state_group_hops_txn(self, txn, state_group):
+        """Given a state group, count how many hops there are in the tree.
+
+        This is used to ensure the delta chains don't get too long.
+        """
+        if isinstance(self.database_engine, PostgresEngine):
+            sql = """
+                WITH RECURSIVE state(state_group) AS (
+                    VALUES(?::bigint)
+                    UNION ALL
+                    SELECT prev_state_group FROM state_group_edges e, state s
+                    WHERE s.state_group = e.state_group
+                )
+                SELECT count(*) FROM state;
+            """
+
+            txn.execute(sql, (state_group,))
+            row = txn.fetchone()
+            if row and row[0]:
+                return row[0]
+            else:
+                return 0
+        else:
+            # We don't use WITH RECURSIVE on sqlite3 as there are distributions
+            # that ship with an sqlite3 version that doesn't support it (e.g. wheezy)
+            next_group = state_group
+            count = 0
+
+            while next_group:
+                next_group = self.db.simple_select_one_onecol_txn(
+                    txn,
+                    table="state_group_edges",
+                    keyvalues={"state_group": next_group},
+                    retcol="prev_state_group",
+                    allow_none=True,
+                )
+                if next_group:
+                    count += 1
+
+            return count
+
+    def _get_state_groups_from_groups_txn(
+        self, txn, groups, state_filter=StateFilter.all()
+    ):
+        results = {group: {} for group in groups}
+
+        where_clause, where_args = state_filter.make_sql_filter_clause()
+
+        # Unless the filter clause is empty, we're going to append it after an
+        # existing where clause
+        if where_clause:
+            where_clause = " AND (%s)" % (where_clause,)
+
+        if isinstance(self.database_engine, PostgresEngine):
+            # Temporarily disable sequential scans in this transaction. This is
+            # a temporary hack until we can add the right indices.
+            txn.execute("SET LOCAL enable_seqscan=off")
+
+            # The below query walks the state_group tree so that the "state"
+            # table includes all state_groups in the tree. It then joins
+            # against `state_groups_state` to fetch the latest state.
+            # It assumes that previous state groups are always numerically
+            # lesser.
+            # The DISTINCT ON, together with the ORDER BY ... state_group DESC,
+            # picks the event_id from the greatest state group for each
+            # (type, state_key), so one row is returned per pair.
+            sql = """
+                WITH RECURSIVE state(state_group) AS (
+                    VALUES(?::bigint)
+                    UNION ALL
+                    SELECT prev_state_group FROM state_group_edges e, state s
+                    WHERE s.state_group = e.state_group
+                )
+                SELECT DISTINCT ON (type, state_key)
+                    type, state_key, event_id
+                FROM state_groups_state
+                WHERE state_group IN (
+                    SELECT state_group FROM state
+                ) %s
+                ORDER BY type, state_key, state_group DESC
+            """
+
+            for group in groups:
+                args = [group]
+                args.extend(where_args)
+
+                txn.execute(sql % (where_clause,), args)
+                for row in txn:
+                    typ, state_key, event_id = row
+                    key = (typ, state_key)
+                    results[group][key] = event_id
+        else:
+            max_entries_returned = state_filter.max_entries_returned()
+
+            # We don't use WITH RECURSIVE on sqlite3 as there are distributions
+            # that ship with an sqlite3 version that doesn't support it (e.g. wheezy)
+            for group in groups:
+                next_group = group
+
+                while next_group:
+                    # We did this before by getting the list of group ids, and
+                    # then passing that list to sqlite to get latest event for
+                    # each (type, state_key). However, that was terribly slow
+                    # without the right indices (which we can't add until
+                    # after we finish deduping state, which requires this func)
+                    args = [next_group]
+                    args.extend(where_args)
+
+                    txn.execute(
+                        "SELECT type, state_key, event_id FROM state_groups_state"
+                        " WHERE state_group = ? " + where_clause,
+                        args,
+                    )
+                    results[group].update(
+                        ((typ, state_key), event_id)
+                        for typ, state_key, event_id in txn
+                        if (typ, state_key) not in results[group]
+                    )
+
+                    # If the number of entries in the (type,state_key)->event_id dict
+                    # matches the number of (type,state_keys) types we were searching
+                    # for, then we must have found them all, so no need to go walk
+                    # further down the tree... UNLESS our types filter contained
+                    # wildcards (i.e. Nones) in which case we have to do an exhaustive
+                    # search
+                    if (
+                        max_entries_returned is not None
+                        and len(results[group]) == max_entries_returned
+                    ):
+                        break
+
+                    next_group = self.db.simple_select_one_onecol_txn(
+                        txn,
+                        table="state_group_edges",
+                        keyvalues={"state_group": next_group},
+                        retcol="prev_state_group",
+                        allow_none=True,
+                    )
+
+        return results
+
+
+class StateBackgroundUpdateStore(StateGroupBackgroundUpdateStore):
+
+    STATE_GROUP_DEDUPLICATION_UPDATE_NAME = "state_group_state_deduplication"
+    STATE_GROUP_INDEX_UPDATE_NAME = "state_group_state_type_index"
+    STATE_GROUPS_ROOM_INDEX_UPDATE_NAME = "state_groups_room_id_idx"
+
+    def __init__(self, database: Database, db_conn, hs):
+        super(StateBackgroundUpdateStore, self).__init__(database, db_conn, hs)
+        self.db.updates.register_background_update_handler(
+            self.STATE_GROUP_DEDUPLICATION_UPDATE_NAME,
+            self._background_deduplicate_state,
+        )
+        self.db.updates.register_background_update_handler(
+            self.STATE_GROUP_INDEX_UPDATE_NAME, self._background_index_state
+        )
+        self.db.updates.register_background_index_update(
+            self.STATE_GROUPS_ROOM_INDEX_UPDATE_NAME,
+            index_name="state_groups_room_id_idx",
+            table="state_groups",
+            columns=["room_id"],
+        )
+
+    @defer.inlineCallbacks
+    def _background_deduplicate_state(self, progress, batch_size):
+        """This background update will slowly deduplicate state by reencoding
+        them as deltas.
+        """
+        last_state_group = progress.get("last_state_group", 0)
+        rows_inserted = progress.get("rows_inserted", 0)
+        max_group = progress.get("max_group", None)
+
+        BATCH_SIZE_SCALE_FACTOR = 100
+
+        batch_size = max(1, int(batch_size / BATCH_SIZE_SCALE_FACTOR))
+
+        if max_group is None:
+            rows = yield self.db.execute(
+                "_background_deduplicate_state",
+                None,
+                "SELECT coalesce(max(id), 0) FROM state_groups",
+            )
+            max_group = rows[0][0]
+
+        def reindex_txn(txn):
+            new_last_state_group = last_state_group
+            for count in range(batch_size):
+                txn.execute(
+                    "SELECT id, room_id FROM state_groups"
+                    " WHERE ? < id AND id <= ?"
+                    " ORDER BY id ASC"
+                    " LIMIT 1",
+                    (new_last_state_group, max_group),
+                )
+                row = txn.fetchone()
+                if row:
+                    state_group, room_id = row
+
+                if not row or not state_group:
+                    return True, count
+
+                txn.execute(
+                    "SELECT state_group FROM state_group_edges"
+                    " WHERE state_group = ?",
+                    (state_group,),
+                )
+
+                # If we reach a point where we've already started inserting
+                # edges we should stop.
+                if txn.fetchall():
+                    return True, count
+
+                txn.execute(
+                    "SELECT coalesce(max(id), 0) FROM state_groups"
+                    " WHERE id < ? AND room_id = ?",
+                    (state_group, room_id),
+                )
+                (prev_group,) = txn.fetchone()
+                new_last_state_group = state_group
+
+                if prev_group:
+                    potential_hops = self._count_state_group_hops_txn(txn, prev_group)
+                    if potential_hops >= MAX_STATE_DELTA_HOPS:
+                        # We want to ensure chains are at most this long,
+                        # otherwise read performance degrades.
+                        continue
+
+                    prev_state = self._get_state_groups_from_groups_txn(
+                        txn, [prev_group]
+                    )
+                    prev_state = prev_state[prev_group]
+
+                    curr_state = self._get_state_groups_from_groups_txn(
+                        txn, [state_group]
+                    )
+                    curr_state = curr_state[state_group]
+
+                    if not set(prev_state.keys()) - set(curr_state.keys()):
+                        # We can only do a delta if the current state's keys are a
+                        # superset of the previous state's keys
+
+                        delta_state = {
+                            key: value
+                            for key, value in iteritems(curr_state)
+                            if prev_state.get(key, None) != value
+                        }
+
+                        self.db.simple_delete_txn(
+                            txn,
+                            table="state_group_edges",
+                            keyvalues={"state_group": state_group},
+                        )
+
+                        self.db.simple_insert_txn(
+                            txn,
+                            table="state_group_edges",
+                            values={
+                                "state_group": state_group,
+                                "prev_state_group": prev_group,
+                            },
+                        )
+
+                        self.db.simple_delete_txn(
+                            txn,
+                            table="state_groups_state",
+                            keyvalues={"state_group": state_group},
+                        )
+
+                        self.db.simple_insert_many_txn(
+                            txn,
+                            table="state_groups_state",
+                            values=[
+                                {
+                                    "state_group": state_group,
+                                    "room_id": room_id,
+                                    "type": key[0],
+                                    "state_key": key[1],
+                                    "event_id": state_id,
+                                }
+                                for key, state_id in iteritems(delta_state)
+                            ],
+                        )
+
+            progress = {
+                "last_state_group": state_group,
+                "rows_inserted": rows_inserted + batch_size,
+                "max_group": max_group,
+            }
+
+            self.db.updates._background_update_progress_txn(
+                txn, self.STATE_GROUP_DEDUPLICATION_UPDATE_NAME, progress
+            )
+
+            return False, batch_size
+
+        finished, result = yield self.db.runInteraction(
+            self.STATE_GROUP_DEDUPLICATION_UPDATE_NAME, reindex_txn
+        )
+
+        if finished:
+            yield self.db.updates._end_background_update(
+                self.STATE_GROUP_DEDUPLICATION_UPDATE_NAME
+            )
+
+        return result * BATCH_SIZE_SCALE_FACTOR
+
+    @defer.inlineCallbacks
+    def _background_index_state(self, progress, batch_size):
+        def reindex_txn(conn):
+            conn.rollback()
+            if isinstance(self.database_engine, PostgresEngine):
+                # postgres insists on autocommit for the index
+                conn.set_session(autocommit=True)
+                try:
+                    txn = conn.cursor()
+                    txn.execute(
+                        "CREATE INDEX CONCURRENTLY state_groups_state_type_idx"
+                        " ON state_groups_state(state_group, type, state_key)"
+                    )
+                    txn.execute("DROP INDEX IF EXISTS state_groups_state_id")
+                finally:
+                    conn.set_session(autocommit=False)
+            else:
+                txn = conn.cursor()
+                txn.execute(
+                    "CREATE INDEX state_groups_state_type_idx"
+                    " ON state_groups_state(state_group, type, state_key)"
+                )
+                txn.execute("DROP INDEX IF EXISTS state_groups_state_id")
+
+        yield self.db.runWithConnection(reindex_txn)
+
+        yield self.db.updates._end_background_update(self.STATE_GROUP_INDEX_UPDATE_NAME)
+
+        return 1
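
The deduplication update above only re-encodes a state group as a delta against its predecessor when two conditions hold: the chain of deltas is still shorter than MAX_STATE_DELTA_HOPS, and the previous group's keys are a subset of the current group's, since a delta row can only add or override entries, never remove them. A small standalone sketch of the key check and the delta computation, using plain dicts keyed by (type, state_key); the helper name is illustrative:

    def compute_delta(prev_state, curr_state):
        """Return the entries of curr_state that differ from prev_state, or
        None if curr_state cannot be expressed as a delta over prev_state."""
        if set(prev_state) - set(curr_state):
            # Some keys disappeared, so a delta cannot represent curr_state.
            return None
        return {
            key: event_id
            for key, event_id in curr_state.items()
            if prev_state.get(key) != event_id
        }

    prev = {("m.room.create", ""): "$create", ("m.room.member", "@a:test"): "$join_a"}
    curr = {**prev, ("m.room.member", "@b:test"): "$join_b"}
    print(compute_delta(prev, curr))  # {('m.room.member', '@b:test'): '$join_b'}
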
diff --git a/synapse/storage/schema/delta/23/drop_state_index.sql b/synapse/storage/data_stores/state/schema/delta/23/drop_state_index.sql
index ae09fa0065..ae09fa0065 100644
--- a/synapse/storage/schema/delta/23/drop_state_index.sql
+++ b/synapse/storage/data_stores/state/schema/delta/23/drop_state_index.sql
diff --git a/synapse/storage/schema/delta/30/state_stream.sql b/synapse/storage/data_stores/state/schema/delta/30/state_stream.sql
index e85699e82e..e85699e82e 100644
--- a/synapse/storage/schema/delta/30/state_stream.sql
+++ b/synapse/storage/data_stores/state/schema/delta/30/state_stream.sql
diff --git a/synapse/storage/data_stores/state/schema/delta/32/remove_state_indices.sql b/synapse/storage/data_stores/state/schema/delta/32/remove_state_indices.sql
new file mode 100644
index 0000000000..1450313bfa
--- /dev/null
+++ b/synapse/storage/data_stores/state/schema/delta/32/remove_state_indices.sql
@@ -0,0 +1,19 @@
+/* Copyright 2016 OpenMarket Ltd
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+
+-- The following indices are redundant, other indices are equivalent or
+-- supersets
+DROP INDEX IF EXISTS state_groups_id; -- Duplicate of PRIMARY KEY
diff --git a/synapse/storage/schema/delta/35/add_state_index.sql b/synapse/storage/data_stores/state/schema/delta/35/add_state_index.sql
index 0fce26345b..33980d02f0 100644
--- a/synapse/storage/schema/delta/35/add_state_index.sql
+++ b/synapse/storage/data_stores/state/schema/delta/35/add_state_index.sql
@@ -13,8 +13,5 @@
  * limitations under the License.
  */
 
-
-ALTER TABLE background_updates ADD COLUMN depends_on TEXT;
-
 INSERT into background_updates (update_name, progress_json, depends_on)
     VALUES ('state_group_state_type_index', '{}', 'state_group_state_deduplication');
diff --git a/synapse/storage/schema/delta/35/state.sql b/synapse/storage/data_stores/state/schema/delta/35/state.sql
index 0f1fa68a89..0f1fa68a89 100644
--- a/synapse/storage/schema/delta/35/state.sql
+++ b/synapse/storage/data_stores/state/schema/delta/35/state.sql
diff --git a/synapse/storage/schema/delta/35/state_dedupe.sql b/synapse/storage/data_stores/state/schema/delta/35/state_dedupe.sql
index 97e5067ef4..97e5067ef4 100644
--- a/synapse/storage/schema/delta/35/state_dedupe.sql
+++ b/synapse/storage/data_stores/state/schema/delta/35/state_dedupe.sql
diff --git a/synapse/storage/schema/delta/47/state_group_seq.py b/synapse/storage/data_stores/state/schema/delta/47/state_group_seq.py
index 9fd1ccf6f7..9fd1ccf6f7 100644
--- a/synapse/storage/schema/delta/47/state_group_seq.py
+++ b/synapse/storage/data_stores/state/schema/delta/47/state_group_seq.py
diff --git a/synapse/storage/data_stores/state/schema/delta/56/state_group_room_idx.sql b/synapse/storage/data_stores/state/schema/delta/56/state_group_room_idx.sql
new file mode 100644
index 0000000000..7916ef18b2
--- /dev/null
+++ b/synapse/storage/data_stores/state/schema/delta/56/state_group_room_idx.sql
@@ -0,0 +1,17 @@
+/* Copyright 2019 The Matrix.org Foundation C.I.C.
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+INSERT INTO background_updates (update_name, progress_json) VALUES
+    ('state_groups_room_id_idx', '{}');
diff --git a/synapse/storage/data_stores/state/schema/full_schemas/54/full.sql b/synapse/storage/data_stores/state/schema/full_schemas/54/full.sql
new file mode 100644
index 0000000000..35f97d6b3d
--- /dev/null
+++ b/synapse/storage/data_stores/state/schema/full_schemas/54/full.sql
@@ -0,0 +1,37 @@
+/* Copyright 2019 The Matrix.org Foundation C.I.C
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+CREATE TABLE state_groups (
+    id BIGINT PRIMARY KEY,
+    room_id TEXT NOT NULL,
+    event_id TEXT NOT NULL
+);
+
+CREATE TABLE state_groups_state (
+    state_group BIGINT NOT NULL,
+    room_id TEXT NOT NULL,
+    type TEXT NOT NULL,
+    state_key TEXT NOT NULL,
+    event_id TEXT NOT NULL
+);
+
+CREATE TABLE state_group_edges (
+    state_group BIGINT NOT NULL,
+    prev_state_group BIGINT NOT NULL
+);
+
+CREATE INDEX state_group_edges_idx ON state_group_edges (state_group);
+CREATE INDEX state_group_edges_prev_idx ON state_group_edges (prev_state_group);
+CREATE INDEX state_groups_state_type_idx ON state_groups_state (state_group, type, state_key);
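
These are the tables that the delta-chain walk in _get_state_groups_from_groups_txn operates over: a state group either stores its state rows outright or stores only the entries that differ from its prev_state_group. A minimal sqlite3 sketch of resolving one group's full state by walking the edges, nearer groups taking precedence (illustrative only; the real code uses the recursive CTE on Postgres and the iterative walk shown earlier on SQLite):

    import sqlite3

    conn = sqlite3.connect(":memory:")
    conn.executescript("""
        CREATE TABLE state_groups_state (
            state_group BIGINT, room_id TEXT, type TEXT, state_key TEXT, event_id TEXT);
        CREATE TABLE state_group_edges (state_group BIGINT, prev_state_group BIGINT);
        INSERT INTO state_groups_state VALUES
            (1, '!r', 'm.room.create', '', '$create'),
            (2, '!r', 'm.room.member', '@a:test', '$join_a');
        INSERT INTO state_group_edges VALUES (2, 1);
    """)

    def resolve_state(conn, state_group):
        """Walk prev_state_group edges back to the root; nearer groups win."""
        state, group = {}, state_group
        while group is not None:
            rows = conn.execute(
                "SELECT type, state_key, event_id FROM state_groups_state"
                " WHERE state_group = ?",
                (group,),
            )
            for typ, key, event_id in rows:
                state.setdefault((typ, key), event_id)  # keep the nearest group's entry
            row = conn.execute(
                "SELECT prev_state_group FROM state_group_edges WHERE state_group = ?",
                (group,),
            ).fetchone()
            group = row[0] if row else None
        return state

    print(resolve_state(conn, 2))
    # {('m.room.member', '@a:test'): '$join_a', ('m.room.create', ''): '$create'}
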
diff --git a/synapse/storage/data_stores/state/schema/full_schemas/54/sequence.sql.postgres b/synapse/storage/data_stores/state/schema/full_schemas/54/sequence.sql.postgres
new file mode 100644
index 0000000000..fcd926c9fb
--- /dev/null
+++ b/synapse/storage/data_stores/state/schema/full_schemas/54/sequence.sql.postgres
@@ -0,0 +1,21 @@
+/* Copyright 2019 The Matrix.org Foundation C.I.C
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+CREATE SEQUENCE state_group_id_seq
+    START WITH 1
+    INCREMENT BY 1
+    NO MINVALUE
+    NO MAXVALUE
+    CACHE 1;
diff --git a/synapse/storage/data_stores/state/store.py b/synapse/storage/data_stores/state/store.py
new file mode 100644
index 0000000000..f3ad1e4369
--- /dev/null
+++ b/synapse/storage/data_stores/state/store.py
@@ -0,0 +1,642 @@
+# -*- coding: utf-8 -*-
+# Copyright 2014-2016 OpenMarket Ltd
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import logging
+from collections import namedtuple
+from typing import Dict, Iterable, List, Set, Tuple
+
+from six import iteritems
+from six.moves import range
+
+from twisted.internet import defer
+
+from synapse.api.constants import EventTypes
+from synapse.storage._base import SQLBaseStore
+from synapse.storage.data_stores.state.bg_updates import StateBackgroundUpdateStore
+from synapse.storage.database import Database
+from synapse.storage.state import StateFilter
+from synapse.types import StateMap
+from synapse.util.caches.descriptors import cached
+from synapse.util.caches.dictionary_cache import DictionaryCache
+
+logger = logging.getLogger(__name__)
+
+
+MAX_STATE_DELTA_HOPS = 100
+
+
+class _GetStateGroupDelta(
+    namedtuple("_GetStateGroupDelta", ("prev_group", "delta_ids"))
+):
+    """Return type of get_state_group_delta that implements __len__, which lets
+    us use the iterable flag when caching
+    """
+
+    __slots__ = []
+
+    def __len__(self):
+        return len(self.delta_ids) if self.delta_ids else 0
+
+
+class StateGroupDataStore(StateBackgroundUpdateStore, SQLBaseStore):
+    """A data store for fetching/storing state groups.
+    """
+
+    def __init__(self, database: Database, db_conn, hs):
+        super(StateGroupDataStore, self).__init__(database, db_conn, hs)
+
+        # Originally the state store used a single DictionaryCache to cache the
+        # event IDs for the state types in a given state group to avoid hammering
+        # on the state_group* tables.
+        #
+        # The point of using a DictionaryCache is that it can cache a subset
+        # of the state events for a given state group (i.e. a subset of the keys for a
+        # given dict which is an entry in the cache for a given state group ID).
+        #
+        # However, this poses problems when performing complicated queries
+        # on the store - for instance: "give me all the state for this group, but
+        # limit members to this subset of users", as DictionaryCache's API isn't
+        # rich enough to say "please cache any of these fields, apart from this subset".
+        # This is problematic when lazy loading members, which requires this behaviour,
+        # as without it the cache has no choice but to speculatively load all
+        # state events for the group, which negates the efficiency being sought.
+        #
+        # Rather than overcomplicating DictionaryCache's API, we instead split the
+        # state_group_cache into two halves - one for tracking non-member events,
+        # and the other for tracking member_events.  This means that lazy loading
+        # queries can be made in a cache-friendly manner by querying both caches
+        # separately and then merging the result.  So for the example above, you
+        # would query the members cache for a specific subset of state keys
+        # (which DictionaryCache will handle efficiently and fine) and the non-members
+        # cache for all state (which DictionaryCache will similarly handle fine)
+        # and then just merge the results together.
+        #
+        # We size the non-members cache to be smaller than the members cache as the
+        # vast majority of state in Matrix (today) is member events.
+
+        self._state_group_cache = DictionaryCache(
+            "*stateGroupCache*",
+            # TODO: this hasn't been tuned yet
+            50000,
+        )
+        self._state_group_members_cache = DictionaryCache(
+            "*stateGroupMembersCache*", 500000,
+        )
+
+    @cached(max_entries=10000, iterable=True)
+    def get_state_group_delta(self, state_group):
+        """Given a state group try to return a previous group and a delta between
+        the old and the new.
+
+        Returns:
+            (prev_group, delta_ids), where both may be None.
+        """
+
+        def _get_state_group_delta_txn(txn):
+            prev_group = self.db.simple_select_one_onecol_txn(
+                txn,
+                table="state_group_edges",
+                keyvalues={"state_group": state_group},
+                retcol="prev_state_group",
+                allow_none=True,
+            )
+
+            if not prev_group:
+                return _GetStateGroupDelta(None, None)
+
+            delta_ids = self.db.simple_select_list_txn(
+                txn,
+                table="state_groups_state",
+                keyvalues={"state_group": state_group},
+                retcols=("type", "state_key", "event_id"),
+            )
+
+            return _GetStateGroupDelta(
+                prev_group,
+                {(row["type"], row["state_key"]): row["event_id"] for row in delta_ids},
+            )
+
+        return self.db.runInteraction(
+            "get_state_group_delta", _get_state_group_delta_txn
+        )
+
+    @defer.inlineCallbacks
+    def _get_state_groups_from_groups(
+        self, groups: List[int], state_filter: StateFilter
+    ):
+        """Returns the state groups for a given set of groups from the
+        database, filtering on types of state events.
+
+        Args:
+            groups: list of state group IDs to query
+            state_filter: The state filter used to fetch state
+                from the database.
+        Returns:
+            Deferred[Dict[int, StateMap[str]]]: Dict of state group to state map.
+        """
+        results = {}
+
+        chunks = [groups[i : i + 100] for i in range(0, len(groups), 100)]
+        for chunk in chunks:
+            res = yield self.db.runInteraction(
+                "_get_state_groups_from_groups",
+                self._get_state_groups_from_groups_txn,
+                chunk,
+                state_filter,
+            )
+            results.update(res)
+
+        return results
+
+    def _get_state_for_group_using_cache(self, cache, group, state_filter):
+        """Checks if group is in cache. See `_get_state_for_groups`
+
+        Args:
+            cache(DictionaryCache): the state group cache to use
+            group(int): The state group to lookup
+            state_filter (StateFilter): The state filter used to fetch state
+                from the database.
+
+        Returns 2-tuple (`state_dict`, `got_all`).
+        `got_all` is a bool indicating if we successfully retrieved all the
+        requested state from the cache; if False we need to query the DB for the
+        missing state.
+        """
+        is_all, known_absent, state_dict_ids = cache.get(group)
+
+        if is_all or state_filter.is_full():
+            # Either we have everything or want everything, either way
+            # `is_all` tells us whether we've gotten everything.
+            return state_filter.filter_state(state_dict_ids), is_all
+
+        # tracks whether any of our requested types are missing from the cache
+        missing_types = False
+
+        if state_filter.has_wildcards():
+            # We don't know if we fetched all the state keys for the types in
+            # the filter that are wildcards, so we have to assume that we may
+            # have missed some.
+            missing_types = True
+        else:
+            # There aren't any wild cards, so `concrete_types()` returns the
+            # complete list of event types we're wanting.
+            for key in state_filter.concrete_types():
+                if key not in state_dict_ids and key not in known_absent:
+                    missing_types = True
+                    break
+
+        return state_filter.filter_state(state_dict_ids), not missing_types
+
+    @defer.inlineCallbacks
+    def _get_state_for_groups(
+        self, groups: Iterable[int], state_filter: StateFilter = StateFilter.all()
+    ):
+        """Gets the state at each of a list of state groups, optionally
+        filtering by type/state_key
+
+        Args:
+            groups: list of state groups for which we want
+                to get the state.
+            state_filter: The state filter used to fetch state
+                from the database.
+        Returns:
+            Deferred[Dict[int, StateMap[str]]]: Dict of state group to state map.
+        """
+
+        member_filter, non_member_filter = state_filter.get_member_split()
+
+        # Now we look them up in the member and non-member caches
+        (
+            non_member_state,
+            incomplete_groups_nm,
+        ) = yield self._get_state_for_groups_using_cache(
+            groups, self._state_group_cache, state_filter=non_member_filter
+        )
+
+        (
+            member_state,
+            incomplete_groups_m,
+        ) = yield self._get_state_for_groups_using_cache(
+            groups, self._state_group_members_cache, state_filter=member_filter
+        )
+
+        state = dict(non_member_state)
+        for group in groups:
+            state[group].update(member_state[group])
+
+        # Now fetch any missing groups from the database
+
+        incomplete_groups = incomplete_groups_m | incomplete_groups_nm
+
+        if not incomplete_groups:
+            return state
+
+        cache_sequence_nm = self._state_group_cache.sequence
+        cache_sequence_m = self._state_group_members_cache.sequence
+
+        # Help the cache hit ratio by expanding the filter a bit
+        db_state_filter = state_filter.return_expanded()
+
+        group_to_state_dict = yield self._get_state_groups_from_groups(
+            list(incomplete_groups), state_filter=db_state_filter
+        )
+
+        # Now lets update the caches
+        self._insert_into_cache(
+            group_to_state_dict,
+            db_state_filter,
+            cache_seq_num_members=cache_sequence_m,
+            cache_seq_num_non_members=cache_sequence_nm,
+        )
+
+        # And finally update the result dict, by filtering out any extra
+        # stuff we pulled out of the database.
+        for group, group_state_dict in iteritems(group_to_state_dict):
+            # We just replace any existing entries, as we will have loaded
+            # everything we need from the database anyway.
+            state[group] = state_filter.filter_state(group_state_dict)
+
+        return state
+
+    def _get_state_for_groups_using_cache(
+        self, groups: Iterable[int], cache: DictionaryCache, state_filter: StateFilter
+    ) -> Tuple[Dict[int, StateMap[str]], Set[int]]:
+        """Gets the state at each of a list of state groups, optionally
+        filtering by type/state_key, querying from a specific cache.
+
+        Args:
+            groups: list of state groups for which we want to get the state.
+            cache: the cache of group ids to state dicts which
+                we will pass through - either the normal state cache or the
+                specific members state cache.
+            state_filter: The state filter used to fetch state from the
+                database.
+
+        Returns:
+            Tuple of dict of state_group_id to state map of entries in the
+            cache, and the state group ids either missing from the cache or
+            incomplete.
+        """
+        results = {}
+        incomplete_groups = set()
+        for group in set(groups):
+            state_dict_ids, got_all = self._get_state_for_group_using_cache(
+                cache, group, state_filter
+            )
+            results[group] = state_dict_ids
+
+            if not got_all:
+                incomplete_groups.add(group)
+
+        return results, incomplete_groups
+
+    def _insert_into_cache(
+        self,
+        group_to_state_dict,
+        state_filter,
+        cache_seq_num_members,
+        cache_seq_num_non_members,
+    ):
+        """Inserts results from querying the database into the relevant cache.
+
+        Args:
+            group_to_state_dict (dict): The new entries pulled from database.
+                Map from state group to state dict
+            state_filter (StateFilter): The state filter used to fetch state
+                from the database.
+            cache_seq_num_members (int): Sequence number of member cache since
+                last lookup in cache
+            cache_seq_num_non_members (int): Sequence number of non-member cache since
+                last lookup in cache
+        """
+
+        # We need to work out which types we've fetched from the DB for the
+        # member vs non-member caches. This should be as accurate as possible,
+        # but can be an underestimate (e.g. when we have wild cards)
+
+        member_filter, non_member_filter = state_filter.get_member_split()
+        if member_filter.is_full():
+            # We fetched all member events
+            member_types = None
+        else:
+            # `concrete_types()` will only return a subset when there are wild
+            # cards in the filter, but that's fine.
+            member_types = member_filter.concrete_types()
+
+        if non_member_filter.is_full():
+            # We fetched all non member events
+            non_member_types = None
+        else:
+            non_member_types = non_member_filter.concrete_types()
+
+        for group, group_state_dict in iteritems(group_to_state_dict):
+            state_dict_members = {}
+            state_dict_non_members = {}
+
+            for k, v in iteritems(group_state_dict):
+                if k[0] == EventTypes.Member:
+                    state_dict_members[k] = v
+                else:
+                    state_dict_non_members[k] = v
+
+            self._state_group_members_cache.update(
+                cache_seq_num_members,
+                key=group,
+                value=state_dict_members,
+                fetched_keys=member_types,
+            )
+
+            self._state_group_cache.update(
+                cache_seq_num_non_members,
+                key=group,
+                value=state_dict_non_members,
+                fetched_keys=non_member_types,
+            )
+
+    def store_state_group(
+        self, event_id, room_id, prev_group, delta_ids, current_state_ids
+    ):
+        """Store a new set of state, returning a newly assigned state group.
+
+        Args:
+            event_id (str): The event ID for which the state was calculated
+            room_id (str)
+            prev_group (int|None): A previous state group for the room, optional.
+            delta_ids (dict|None): The delta between state at `prev_group` and
+                `current_state_ids`, if `prev_group` was given. Same format as
+                `current_state_ids`.
+            current_state_ids (dict): The state to store. Map of (type, state_key)
+                to event_id.
+
+        Returns:
+            Deferred[int]: The state group ID
+        """
+
+        def _store_state_group_txn(txn):
+            if current_state_ids is None:
+                # AFAIK, this can never happen
+                raise Exception("current_state_ids cannot be None")
+
+            state_group = self.database_engine.get_next_state_group_id(txn)
+
+            self.db.simple_insert_txn(
+                txn,
+                table="state_groups",
+                values={"id": state_group, "room_id": room_id, "event_id": event_id},
+            )
+
+            # We persist as a delta if we can, while also ensuring the chain
+            # of deltas isn't too long, as otherwise read performance degrades.
+            if prev_group:
+                is_in_db = self.db.simple_select_one_onecol_txn(
+                    txn,
+                    table="state_groups",
+                    keyvalues={"id": prev_group},
+                    retcol="id",
+                    allow_none=True,
+                )
+                if not is_in_db:
+                    raise Exception(
+                        "Trying to persist state with unpersisted prev_group: %r"
+                        % (prev_group,)
+                    )
+
+                potential_hops = self._count_state_group_hops_txn(txn, prev_group)
+            if prev_group and potential_hops < MAX_STATE_DELTA_HOPS:
+                self.db.simple_insert_txn(
+                    txn,
+                    table="state_group_edges",
+                    values={"state_group": state_group, "prev_state_group": prev_group},
+                )
+
+                self.db.simple_insert_many_txn(
+                    txn,
+                    table="state_groups_state",
+                    values=[
+                        {
+                            "state_group": state_group,
+                            "room_id": room_id,
+                            "type": key[0],
+                            "state_key": key[1],
+                            "event_id": state_id,
+                        }
+                        for key, state_id in iteritems(delta_ids)
+                    ],
+                )
+            else:
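+                # No (usable) prev_group, so store the full state for this group.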
+                self.db.simple_insert_many_txn(
+                    txn,
+                    table="state_groups_state",
+                    values=[
+                        {
+                            "state_group": state_group,
+                            "room_id": room_id,
+                            "type": key[0],
+                            "state_key": key[1],
+                            "event_id": state_id,
+                        }
+                        for key, state_id in iteritems(current_state_ids)
+                    ],
+                )
+
+            # Prefill the state group caches with this group.
+            # It's fine to use the sequence like this as the state group map
+            # is immutable. (If the map wasn't immutable then this prefill could
+            # race with another update)
+
+            current_member_state_ids = {
+                s: ev
+                for (s, ev) in iteritems(current_state_ids)
+                if s[0] == EventTypes.Member
+            }
+            txn.call_after(
+                self._state_group_members_cache.update,
+                self._state_group_members_cache.sequence,
+                key=state_group,
+                value=dict(current_member_state_ids),
+            )
+
+            current_non_member_state_ids = {
+                s: ev
+                for (s, ev) in iteritems(current_state_ids)
+                if s[0] != EventTypes.Member
+            }
+            txn.call_after(
+                self._state_group_cache.update,
+                self._state_group_cache.sequence,
+                key=state_group,
+                value=dict(current_non_member_state_ids),
+            )
+
+            return state_group
+
+        return self.db.runInteraction("store_state_group", _store_state_group_txn)
+
+    def purge_unreferenced_state_groups(
+        self, room_id: str, state_groups_to_delete
+    ) -> defer.Deferred:
+        """Deletes no longer referenced state groups and de-deltas any state
+        groups that reference them.
+
+        Args:
+            room_id: The room the state groups belong to (must all be in the
+                same room).
+            state_groups_to_delete (Collection[int]): Set of all state groups
+                to delete.
+        """
+
+        return self.db.runInteraction(
+            "purge_unreferenced_state_groups",
+            self._purge_unreferenced_state_groups,
+            room_id,
+            state_groups_to_delete,
+        )
+
+    def _purge_unreferenced_state_groups(self, txn, room_id, state_groups_to_delete):
+        logger.info(
+            "[purge] found %i state groups to delete", len(state_groups_to_delete)
+        )
+
+        rows = self.db.simple_select_many_txn(
+            txn,
+            table="state_group_edges",
+            column="prev_state_group",
+            iterable=state_groups_to_delete,
+            keyvalues={},
+            retcols=("state_group",),
+        )
+
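+        # Find the state groups that are stored as a delta against a group we
+        # are about to delete, but are not themselves being deleted: these
+        # need to be converted to full state below.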
+        remaining_state_groups = {
+            row["state_group"]
+            for row in rows
+            if row["state_group"] not in state_groups_to_delete
+        }
+
+        logger.info(
+            "[purge] de-delta-ing %i remaining state groups",
+            len(remaining_state_groups),
+        )
+
+        # Now we turn the state groups that reference to-be-deleted state
+        # groups into non-delta versions.
+        for sg in remaining_state_groups:
+            logger.info("[purge] de-delta-ing remaining state group %s", sg)
+            curr_state = self._get_state_groups_from_groups_txn(txn, [sg])
+            curr_state = curr_state[sg]
+
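+            # Replace the delta rows and edges for this group with its full,
+            # flattened state.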
+            self.db.simple_delete_txn(
+                txn, table="state_groups_state", keyvalues={"state_group": sg}
+            )
+
+            self.db.simple_delete_txn(
+                txn, table="state_group_edges", keyvalues={"state_group": sg}
+            )
+
+            self.db.simple_insert_many_txn(
+                txn,
+                table="state_groups_state",
+                values=[
+                    {
+                        "state_group": sg,
+                        "room_id": room_id,
+                        "type": key[0],
+                        "state_key": key[1],
+                        "event_id": state_id,
+                    }
+                    for key, state_id in iteritems(curr_state)
+                ],
+            )
+
+        logger.info("[purge] removing redundant state groups")
+        txn.executemany(
+            "DELETE FROM state_groups_state WHERE state_group = ?",
+            ((sg,) for sg in state_groups_to_delete),
+        )
+        txn.executemany(
+            "DELETE FROM state_groups WHERE id = ?",
+            ((sg,) for sg in state_groups_to_delete),
+        )
+
+    @defer.inlineCallbacks
+    def get_previous_state_groups(self, state_groups):
+        """Fetch the previous groups of the given state groups.
+
+        Args:
+            state_groups (Iterable[int])
+
+        Returns:
+            Deferred[dict[int, int]]: mapping from state group to previous
+            state group.
+        """
+
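+        # Each row in state_group_edges maps a state group to the group it is
+        # stored as a delta against.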
+        rows = yield self.db.simple_select_many_batch(
+            table="state_group_edges",
+            column="prev_state_group",
+            iterable=state_groups,
+            keyvalues={},
+            retcols=("prev_state_group", "state_group"),
+            desc="get_previous_state_groups",
+        )
+
+        return {row["state_group"]: row["prev_state_group"] for row in rows}
+
+    def purge_room_state(self, room_id, state_groups_to_delete):
+        """Deletes all record of a room from state tables
+
+        Args:
+            room_id (str):
+            state_groups_to_delete (list[int]): State groups to delete
+        """
+
+        return self.db.runInteraction(
+            "purge_room_state",
+            self._purge_room_state_txn,
+            room_id,
+            state_groups_to_delete,
+        )
+
+    def _purge_room_state_txn(self, txn, room_id, state_groups_to_delete):
+        # first we have to delete the state groups' states
+        logger.info("[purge] removing %s from state_groups_state", room_id)
+
+        self.db.simple_delete_many_txn(
+            txn,
+            table="state_groups_state",
+            column="state_group",
+            iterable=state_groups_to_delete,
+            keyvalues={},
+        )
+
+        # ... and the state group edges
+        logger.info("[purge] removing %s from state_group_edges", room_id)
+
+        self.db.simple_delete_many_txn(
+            txn,
+            table="state_group_edges",
+            column="state_group",
+            iterable=state_groups_to_delete,
+            keyvalues={},
+        )
+
+        # ... and the state groups
+        logger.info("[purge] removing %s from state_groups", room_id)
+
+        self.db.simple_delete_many_txn(
+            txn,
+            table="state_groups",
+            column="id",
+            iterable=state_groups_to_delete,
+            keyvalues={},
+        )
diff --git a/synapse/storage/database.py b/synapse/storage/database.py
new file mode 100644
index 0000000000..b112ff3df2
--- /dev/null
+++ b/synapse/storage/database.py
@@ -0,0 +1,1644 @@
+# -*- coding: utf-8 -*-
+# Copyright 2014-2016 OpenMarket Ltd
+# Copyright 2017-2018 New Vector Ltd
+# Copyright 2019 The Matrix.org Foundation C.I.C.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+import logging
+import time
+from time import monotonic as monotonic_time
+from typing import (
+    Any,
+    Callable,
+    Dict,
+    Iterable,
+    Iterator,
+    List,
+    Optional,
+    Tuple,
+    TypeVar,
+)
+
+from six import iteritems, iterkeys, itervalues
+from six.moves import intern, range
+
+from prometheus_client import Histogram
+
+from twisted.enterprise import adbapi
+from twisted.internet import defer
+
+from synapse.api.errors import StoreError
+from synapse.config.database import DatabaseConnectionConfig
+from synapse.logging.context import (
+    LoggingContext,
+    LoggingContextOrSentinel,
+    current_context,
+    make_deferred_yieldable,
+)
+from synapse.metrics.background_process_metrics import run_as_background_process
+from synapse.storage.background_updates import BackgroundUpdater
+from synapse.storage.engines import BaseDatabaseEngine, PostgresEngine, Sqlite3Engine
+from synapse.storage.types import Connection, Cursor
+from synapse.types import Collection
+
+logger = logging.getLogger(__name__)
+
+# Python 3 has no maximum int value, so define our own cap for transaction IDs
+MAX_TXN_ID = 2 ** 63 - 1
+
+sql_logger = logging.getLogger("synapse.storage.SQL")
+transaction_logger = logging.getLogger("synapse.storage.txn")
+perf_logger = logging.getLogger("synapse.storage.TIME")
+
+sql_scheduling_timer = Histogram("synapse_storage_schedule_time", "sec")
+
+sql_query_timer = Histogram("synapse_storage_query_time", "sec", ["verb"])
+sql_txn_timer = Histogram("synapse_storage_transaction_time", "sec", ["desc"])
+
+
+# Unique indexes which have been added in background updates. Maps from table name
+# to the name of the background update which added the unique index to that table.
+#
+# This is used by the upsert logic to figure out which tables are safe to do a proper
+# UPSERT on: until the relevant background update has completed, we
+# have to emulate an upsert by locking the table.
+#
+UNIQUE_INDEX_BACKGROUND_UPDATES = {
+    "user_ips": "user_ips_device_unique_index",
+    "device_lists_remote_extremeties": "device_lists_remote_extremeties_unique_idx",
+    "device_lists_remote_cache": "device_lists_remote_cache_unique_idx",
+    "event_search": "event_search_event_id_idx",
+}
+
+
+def make_pool(
+    reactor, db_config: DatabaseConnectionConfig, engine: BaseDatabaseEngine
+) -> adbapi.ConnectionPool:
+    """Get the connection pool for the database.
+    """
+
+    return adbapi.ConnectionPool(
+        db_config.config["name"],
+        cp_reactor=reactor,
+        cp_openfun=engine.on_new_connection,
+        **db_config.config.get("args", {})
+    )
+
+
+def make_conn(
+    db_config: DatabaseConnectionConfig, engine: BaseDatabaseEngine
+) -> Connection:
+    """Make a new connection to the database and return it.
+
+    Returns:
+        Connection
+    """
+
+    db_params = {
+        k: v
+        for k, v in db_config.config.get("args", {}).items()
+        if not k.startswith("cp_")
+    }
+    db_conn = engine.module.connect(**db_params)
+    engine.on_new_connection(db_conn)
+    return db_conn
+
+
+# The type of entry which goes on our after_callbacks and exception_callbacks lists.
+#
+# Python 3.5.2 doesn't support Callable with an ellipsis, so we wrap it in quotes so
+# that mypy sees the type but the runtime python doesn't.
+_CallbackListEntry = Tuple["Callable[..., None]", Iterable[Any], Dict[str, Any]]
+
+
+class LoggingTransaction:
+    """An object that almost-transparently proxies for the 'txn' object
+    passed to the constructor. Adds logging and metrics to the .execute()
+    method.
+
+    Args:
+        txn: The database transaction object to wrap.
+        name: The name of this transaction, for logging.
+        database_engine
+        after_callbacks: A list that callbacks will be appended to
+            that have been added by `call_after` which should be run on
+            successful completion of the transaction. None indicates that no
+            callbacks should be allowed to be scheduled to run.
+        exception_callbacks: A list that callbacks will be appended
+            to that have been added by `call_on_exception` which should be run
+            if transaction ends with an error. None indicates that no callbacks
+            should be allowed to be scheduled to run.
+    """
+
+    __slots__ = [
+        "txn",
+        "name",
+        "database_engine",
+        "after_callbacks",
+        "exception_callbacks",
+    ]
+
+    def __init__(
+        self,
+        txn: Cursor,
+        name: str,
+        database_engine: BaseDatabaseEngine,
+        after_callbacks: Optional[List[_CallbackListEntry]] = None,
+        exception_callbacks: Optional[List[_CallbackListEntry]] = None,
+    ):
+        self.txn = txn
+        self.name = name
+        self.database_engine = database_engine
+        self.after_callbacks = after_callbacks
+        self.exception_callbacks = exception_callbacks
+
+    def call_after(self, callback: "Callable[..., None]", *args, **kwargs):
+        """Call the given callback on the main twisted thread after the
+        transaction has finished. Used to invalidate the caches on the
+        correct thread.
+        """
+        # if self.after_callbacks is None, that means that whatever constructed the
+        # LoggingTransaction isn't expecting there to be any callbacks; assert that
+        # is not the case.
+        assert self.after_callbacks is not None
+        self.after_callbacks.append((callback, args, kwargs))
+
+    def call_on_exception(self, callback: "Callable[..., None]", *args, **kwargs):
+        # if self.exception_callbacks is None, that means that whatever constructed the
+        # LoggingTransaction isn't expecting there to be any callbacks; assert that
+        # is not the case.
+        assert self.exception_callbacks is not None
+        self.exception_callbacks.append((callback, args, kwargs))
+
+    def fetchall(self) -> List[Tuple]:
+        return self.txn.fetchall()
+
+    def fetchone(self) -> Tuple:
+        return self.txn.fetchone()
+
+    def __iter__(self) -> Iterator[Tuple]:
+        return self.txn.__iter__()
+
+    @property
+    def rowcount(self) -> int:
+        return self.txn.rowcount
+
+    @property
+    def description(self) -> Any:
+        return self.txn.description
+
+    def execute_batch(self, sql, args):
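+        # On PostgreSQL, psycopg2's execute_batch sends the whole batch in far
+        # fewer round trips; other engines fall back to executing the statement
+        # once per set of arguments.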
+        if isinstance(self.database_engine, PostgresEngine):
+            from psycopg2.extras import execute_batch  # type: ignore
+
+            self._do_execute(lambda *x: execute_batch(self.txn, *x), sql, args)
+        else:
+            for val in args:
+                self.execute(sql, val)
+
+    def execute(self, sql: str, *args: Any):
+        self._do_execute(self.txn.execute, sql, *args)
+
+    def executemany(self, sql: str, *args: Any):
+        self._do_execute(self.txn.executemany, sql, *args)
+
+    def _make_sql_one_line(self, sql: str) -> str:
+        "Strip newlines out of SQL so that the loggers in the DB are on one line"
+        return " ".join(line.strip() for line in sql.splitlines() if line.strip())
+
+    def _do_execute(self, func, sql, *args):
+        sql = self._make_sql_one_line(sql)
+
+        # TODO(paul): Maybe use 'info' and 'debug' for values?
+        sql_logger.debug("[SQL] {%s} %s", self.name, sql)
+
+        sql = self.database_engine.convert_param_style(sql)
+        if args:
+            try:
+                sql_logger.debug("[SQL values] {%s} %r", self.name, args[0])
+            except Exception:
+                # Don't let logging failures stop SQL from working
+                pass
+
+        start = time.time()
+
+        try:
+            return func(sql, *args)
+        except Exception as e:
+            logger.debug("[SQL FAIL] {%s} %s", self.name, e)
+            raise
+        finally:
+            secs = time.time() - start
+            sql_logger.debug("[SQL time] {%s} %f sec", self.name, secs)
+            sql_query_timer.labels(sql.split()[0]).observe(secs)
+
+    def close(self):
+        self.txn.close()
+
+
+class PerformanceCounters(object):
+    def __init__(self):
+        self.current_counters = {}
+        self.previous_counters = {}
+
+    def update(self, key, duration_secs):
+        count, cum_time = self.current_counters.get(key, (0, 0))
+        count += 1
+        cum_time += duration_secs
+        self.current_counters[key] = (count, cum_time)
+
+    def interval(self, interval_duration_secs, limit=3):
+        counters = []
+        for name, (count, cum_time) in iteritems(self.current_counters):
+            prev_count, prev_time = self.previous_counters.get(name, (0, 0))
+            counters.append(
+                (
+                    (cum_time - prev_time) / interval_duration_secs,
+                    count - prev_count,
+                    name,
+                )
+            )
+
+        self.previous_counters = dict(self.current_counters)
+
+        counters.sort(reverse=True)
+
+        top_n_counters = ", ".join(
+            "%s(%d): %.3f%%" % (name, count, 100 * ratio)
+            for ratio, count, name in counters[:limit]
+        )
+
+        return top_n_counters
+
+
+class Database(object):
+    """Wraps a single physical database and connection pool.
+
+    A single database may be used by multiple data stores.
+    """
+
+    _TXN_ID = 0
+
+    def __init__(
+        self, hs, database_config: DatabaseConnectionConfig, engine: BaseDatabaseEngine
+    ):
+        self.hs = hs
+        self._clock = hs.get_clock()
+        self._database_config = database_config
+        self._db_pool = make_pool(hs.get_reactor(), database_config, engine)
+
+        self.updates = BackgroundUpdater(hs, self)
+
+        self._previous_txn_total_time = 0.0
+        self._current_txn_total_time = 0.0
+        self._previous_loop_ts = 0.0
+
+        # TODO(paul): These can eventually be removed once the metrics code
+        #   is running in mainline, and we have some nice monitoring frontends
+        #   to watch it
+        self._txn_perf_counters = PerformanceCounters()
+
+        self.engine = engine
+
+        # A set of tables that are not safe to use native upserts in.
+        self._unsafe_to_upsert_tables = set(UNIQUE_INDEX_BACKGROUND_UPDATES.keys())
+
+        # We add the user_directory_search table to the blacklist on SQLite
+        # because the existing search table does not have an index, making it
+        # unsafe to use native upserts.
+        if isinstance(self.engine, Sqlite3Engine):
+            self._unsafe_to_upsert_tables.add("user_directory_search")
+
+        if self.engine.can_native_upsert:
+            # Check ASAP (and then later, every 1s) to see if we have finished
+            # background updates of tables that aren't safe to update.
+            self._clock.call_later(
+                0.0,
+                run_as_background_process,
+                "upsert_safety_check",
+                self._check_safe_to_upsert,
+            )
+
+    def is_running(self):
+        """Is the database pool currently running
+        """
+        return self._db_pool.running
+
+    @defer.inlineCallbacks
+    def _check_safe_to_upsert(self):
+        """
+        Is it safe to use native UPSERT?
+
+        If there are background updates, we will need to wait, as they may be
+        the addition of indexes that set the UNIQUE constraint that we require.
+
+        If the background updates have not completed, wait 15 sec and check again.
+        """
+        updates = yield self.simple_select_list(
+            "background_updates",
+            keyvalues=None,
+            retcols=["update_name"],
+            desc="check_background_updates",
+        )
+        updates = [x["update_name"] for x in updates]
+
+        for table, update_name in UNIQUE_INDEX_BACKGROUND_UPDATES.items():
+            if update_name not in updates:
+                logger.debug("Now safe to upsert in %s", table)
+                self._unsafe_to_upsert_tables.discard(table)
+
+        # If there's any updates still running, reschedule to run.
+        if updates:
+            self._clock.call_later(
+                15.0,
+                run_as_background_process,
+                "upsert_safety_check",
+                self._check_safe_to_upsert,
+            )
+
+    def start_profiling(self):
+        self._previous_loop_ts = monotonic_time()
+
+        def loop():
+            curr = self._current_txn_total_time
+            prev = self._previous_txn_total_time
+            self._previous_txn_total_time = curr
+
+            time_now = monotonic_time()
+            time_then = self._previous_loop_ts
+            self._previous_loop_ts = time_now
+
+            duration = time_now - time_then
+            ratio = (curr - prev) / duration
+
+            top_three_counters = self._txn_perf_counters.interval(duration, limit=3)
+
+            perf_logger.debug(
+                "Total database time: %.3f%% {%s}", ratio * 100, top_three_counters
+            )
+
+        self._clock.looping_call(loop, 10000)
+
+    def new_transaction(
+        self, conn, desc, after_callbacks, exception_callbacks, func, *args, **kwargs
+    ):
+        start = monotonic_time()
+        txn_id = self._TXN_ID
+
+        # We don't really need these to be unique, so let's stop it from
+        # growing really large.
+        self._TXN_ID = (self._TXN_ID + 1) % (MAX_TXN_ID)
+
+        name = "%s-%x" % (desc, txn_id)
+
+        transaction_logger.debug("[TXN START] {%s}", name)
+
+        try:
+            i = 0
+            N = 5
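+            # Retry the transaction up to N times on transient errors such as
+            # the connection dropping or a deadlock.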
+            while True:
+                cursor = LoggingTransaction(
+                    conn.cursor(),
+                    name,
+                    self.engine,
+                    after_callbacks,
+                    exception_callbacks,
+                )
+                try:
+                    r = func(cursor, *args, **kwargs)
+                    conn.commit()
+                    return r
+                except self.engine.module.OperationalError as e:
+                    # This can happen if the database disappears mid
+                    # transaction.
+                    logger.warning(
+                        "[TXN OPERROR] {%s} %s %d/%d", name, e, i, N,
+                    )
+                    if i < N:
+                        i += 1
+                        try:
+                            conn.rollback()
+                        except self.engine.module.Error as e1:
+                            logger.warning("[TXN EROLL] {%s} %s", name, e1)
+                        continue
+                    raise
+                except self.engine.module.DatabaseError as e:
+                    if self.engine.is_deadlock(e):
+                        logger.warning("[TXN DEADLOCK] {%s} %d/%d", name, i, N)
+                        if i < N:
+                            i += 1
+                            try:
+                                conn.rollback()
+                            except self.engine.module.Error as e1:
+                                logger.warning(
+                                    "[TXN EROLL] {%s} %s", name, e1,
+                                )
+                            continue
+                    raise
+                finally:
+                    # we're either about to retry with a new cursor, or we're about to
+                    # release the connection. Once we release the connection, it could
+                    # get used for another query, which might do a conn.rollback().
+                    #
+                    # In the latter case, even though that probably wouldn't affect the
+                    # results of this transaction, python's sqlite will reset all
+                    # statements on the connection [1], which will make our cursor
+                    # invalid [2].
+                    #
+                    # In any case, continuing to read rows after commit()ing seems
+                    # dubious from the PoV of ACID transactional semantics
+                    # (sqlite explicitly says that once you commit, you may see rows
+                    # from subsequent updates.)
+                    #
+                    # In psycopg2, cursors are essentially a client-side fabrication -
+                    # all the data is transferred to the client side when the statement
+                    # finishes executing - so in theory we could go on streaming results
+                    # from the cursor, but attempting to do so would make us
+                    # incompatible with sqlite, so let's make sure we're not doing that
+                    # by closing the cursor.
+                    #
+                    # (*named* cursors in psycopg2 are different and are proper server-
+                    # side things, but (a) we don't use them and (b) they are implicitly
+                    # closed by ending the transaction anyway.)
+                    #
+                    # In short, if we haven't finished with the cursor yet, that's a
+                    # problem waiting to bite us.
+                    #
+                    # TL;DR: we're done with the cursor, so we can close it.
+                    #
+                    # [1]: https://github.com/python/cpython/blob/v3.8.0/Modules/_sqlite/connection.c#L465
+                    # [2]: https://github.com/python/cpython/blob/v3.8.0/Modules/_sqlite/cursor.c#L236
+                    cursor.close()
+        except Exception as e:
+            logger.debug("[TXN FAIL] {%s} %s", name, e)
+            raise
+        finally:
+            end = monotonic_time()
+            duration = end - start
+
+            current_context().add_database_transaction(duration)
+
+            transaction_logger.debug("[TXN END] {%s} %f sec", name, duration)
+
+            self._current_txn_total_time += duration
+            self._txn_perf_counters.update(desc, duration)
+            sql_txn_timer.labels(desc).observe(duration)
+
+    @defer.inlineCallbacks
+    def runInteraction(self, desc: str, func: Callable, *args: Any, **kwargs: Any):
+        """Starts a transaction on the database and runs a given function
+
+        Arguments:
+            desc: description of the transaction, for logging and metrics
+            func: callback function, which will be called with a
+                database transaction (twisted.enterprise.adbapi.Transaction) as
+                its first argument, followed by `args` and `kwargs`.
+
+            args: positional args to pass to `func`
+            kwargs: named args to pass to `func`
+
+        Returns:
+            Deferred: The result of func
+        """
+        after_callbacks = []  # type: List[_CallbackListEntry]
+        exception_callbacks = []  # type: List[_CallbackListEntry]
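+        # Callbacks registered via LoggingTransaction.call_after and
+        # call_on_exception are collected here and run once the transaction
+        # has committed or failed, respectively.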
+
+        if not current_context():
+            logger.warning("Starting db txn '%s' from sentinel context", desc)
+
+        try:
+            result = yield self.runWithConnection(
+                self.new_transaction,
+                desc,
+                after_callbacks,
+                exception_callbacks,
+                func,
+                *args,
+                **kwargs
+            )
+
+            for after_callback, after_args, after_kwargs in after_callbacks:
+                after_callback(*after_args, **after_kwargs)
+        except:  # noqa: E722, as we reraise the exception this is fine.
+            for after_callback, after_args, after_kwargs in exception_callbacks:
+                after_callback(*after_args, **after_kwargs)
+            raise
+
+        return result
+
+    @defer.inlineCallbacks
+    def runWithConnection(self, func: Callable, *args: Any, **kwargs: Any):
+        """Wraps the .runWithConnection() method on the underlying db_pool.
+
+        Arguments:
+            func: callback function, which will be called with a
+                database connection (twisted.enterprise.adbapi.Connection) as
+                its first argument, followed by `args` and `kwargs`.
+            args: positional args to pass to `func`
+            kwargs: named args to pass to `func`
+
+        Returns:
+            Deferred: The result of func
+        """
+        parent_context = current_context()  # type: Optional[LoggingContextOrSentinel]
+        if not parent_context:
+            logger.warning(
+                "Starting db connection from sentinel context: metrics will be lost"
+            )
+            parent_context = None
+
+        start_time = monotonic_time()
+
+        def inner_func(conn, *args, **kwargs):
+            with LoggingContext("runWithConnection", parent_context) as context:
+                sched_duration_sec = monotonic_time() - start_time
+                sql_scheduling_timer.observe(sched_duration_sec)
+                context.add_database_scheduled(sched_duration_sec)
+
+                if self.engine.is_connection_closed(conn):
+                    logger.debug("Reconnecting closed database connection")
+                    conn.reconnect()
+
+                return func(conn, *args, **kwargs)
+
+        result = yield make_deferred_yieldable(
+            self._db_pool.runWithConnection(inner_func, *args, **kwargs)
+        )
+
+        return result
+
+    @staticmethod
+    def cursor_to_dict(cursor):
+        """Converts a SQL cursor into an list of dicts.
+
+        Args:
+            cursor : The DBAPI cursor which has executed a query.
+        Returns:
+            A list of dicts where the key is the column header.
+        """
+        col_headers = [intern(str(column[0])) for column in cursor.description]
+        results = [dict(zip(col_headers, row)) for row in cursor]
+        return results
+
+    def execute(self, desc, decoder, query, *args):
+        """Runs a single query for a result set.
+
+        Args:
+            decoder - The function which can resolve the cursor results to
+                something meaningful.
+            query - The query string to execute
+            *args - Query args.
+        Returns:
+            The result of decoder(results)
+        """
+
+        def interaction(txn):
+            txn.execute(query, args)
+            if decoder:
+                return decoder(txn)
+            else:
+                return txn.fetchall()
+
+        return self.runInteraction(desc, interaction)
+
+    # "Simple" SQL API methods that operate on a single table with no JOINs,
+    # no complex WHERE clauses, just a dict of values for columns.
+
+    @defer.inlineCallbacks
+    def simple_insert(self, table, values, or_ignore=False, desc="simple_insert"):
+        """Executes an INSERT query on the named table.
+
+        Args:
+            table : string giving the table name
+            values : dict of new column names and values for them
+            or_ignore : bool stating whether to ignore a conflicting row that
+                already exists. If True, no exception is raised and False will
+                be returned by the function instead
+            desc : string giving a description of the transaction
+
+        Returns:
+            bool: Whether the row was inserted or not. Only useful when
+            `or_ignore` is True
+        """
+        try:
+            yield self.runInteraction(desc, self.simple_insert_txn, table, values)
+        except self.engine.module.IntegrityError:
+            # We have to do or_ignore flag at this layer, since we can't reuse
+            # a cursor after we receive an error from the db.
+            if not or_ignore:
+                raise
+            return False
+        return True
+
+    @staticmethod
+    def simple_insert_txn(txn, table, values):
+        keys, vals = zip(*values.items())
+
+        sql = "INSERT INTO %s (%s) VALUES(%s)" % (
+            table,
+            ", ".join(k for k in keys),
+            ", ".join("?" for _ in keys),
+        )
+
+        txn.execute(sql, vals)
+
+    def simple_insert_many(self, table, values, desc):
+        return self.runInteraction(desc, self.simple_insert_many_txn, table, values)
+
+    @staticmethod
+    def simple_insert_many_txn(txn, table, values):
+        if not values:
+            return
+
+        # This is a *slight* abomination to get a list of tuples of key names
+        # and a list of tuples of value names.
+        #
+        # i.e. [{"a": 1, "b": 2}, {"c": 3, "d": 4}]
+        #         => [("a", "b",), ("c", "d",)] and [(1, 2,), (3, 4,)]
+        #
+        # The sort is to ensure that we don't rely on dictionary iteration
+        # order.
+        keys, vals = zip(
+            *[zip(*(sorted(i.items(), key=lambda kv: kv[0]))) for i in values if i]
+        )
+
+        for k in keys:
+            if k != keys[0]:
+                raise RuntimeError("All items must have the same keys")
+
+        sql = "INSERT INTO %s (%s) VALUES(%s)" % (
+            table,
+            ", ".join(k for k in keys[0]),
+            ", ".join("?" for _ in keys[0]),
+        )
+
+        txn.executemany(sql, vals)
+
+    @defer.inlineCallbacks
+    def simple_upsert(
+        self,
+        table,
+        keyvalues,
+        values,
+        insertion_values={},
+        desc="simple_upsert",
+        lock=True,
+    ):
+        """
+
+        `lock` should generally be set to True (the default), but can be set
+        to False if either of the following are true:
+
+        * there is a UNIQUE INDEX on the key columns. In this case a conflict
+          will cause an IntegrityError in which case this function will retry
+          the update.
+
+        * we somehow know that we are the only thread which will be updating
+          this table.
+
+        Args:
+            table (str): The table to upsert into
+            keyvalues (dict): The unique key columns and their new values
+            values (dict): The nonunique columns and their new values
+            insertion_values (dict): additional key/values to use only when
+                inserting
+            lock (bool): True to lock the table when doing the upsert.
+        Returns:
+            Deferred(None or bool): Native upserts always return None. Emulated
+            upserts return True if a new entry was created, False if an existing
+            one was updated.
+        """
+        attempts = 0
+        while True:
+            try:
+                result = yield self.runInteraction(
+                    desc,
+                    self.simple_upsert_txn,
+                    table,
+                    keyvalues,
+                    values,
+                    insertion_values,
+                    lock=lock,
+                )
+                return result
+            except self.engine.module.IntegrityError as e:
+                attempts += 1
+                if attempts >= 5:
+                    # don't retry forever, because things other than races
+                    # can cause IntegrityErrors
+                    raise
+
+                # presumably we raced with another transaction: let's retry.
+                logger.warning(
+                    "IntegrityError when upserting into %s; retrying: %s", table, e
+                )
+
+    def simple_upsert_txn(
+        self, txn, table, keyvalues, values, insertion_values={}, lock=True
+    ):
+        """
+        Pick the UPSERT method which works best on the platform. Either the
+        native one (Pg9.5+, recent SQLites), or fall back to an emulated method.
+
+        Args:
+            txn: The transaction to use.
+            table (str): The table to upsert into
+            keyvalues (dict): The unique key columns and their new values
+            values (dict): The nonunique columns and their new values
+            insertion_values (dict): additional key/values to use only when
+                inserting
+            lock (bool): True to lock the table when doing the upsert.
+        Returns:
+            None or bool: Native upserts always return None. Emulated
+            upserts return True if a new entry was created, False if an existing
+            one was updated.
+        """
+        if self.engine.can_native_upsert and table not in self._unsafe_to_upsert_tables:
+            return self.simple_upsert_txn_native_upsert(
+                txn, table, keyvalues, values, insertion_values=insertion_values
+            )
+        else:
+            return self.simple_upsert_txn_emulated(
+                txn,
+                table,
+                keyvalues,
+                values,
+                insertion_values=insertion_values,
+                lock=lock,
+            )
+
+    def simple_upsert_txn_emulated(
+        self, txn, table, keyvalues, values, insertion_values={}, lock=True
+    ):
+        """
+        Args:
+            table (str): The table to upsert into
+            keyvalues (dict): The unique key columns and their new values
+            values (dict): The nonunique columns and their new values
+            insertion_values (dict): additional key/values to use only when
+                inserting
+            lock (bool): True to lock the table when doing the upsert.
+        Returns:
+            bool: Return True if a new entry was created, False if an existing
+            one was updated.
+        """
+        # We need to lock the table :(, unless we're *really* careful
+        if lock:
+            self.engine.lock_table(txn, table)
+
+        def _getwhere(key):
+            # If the value we're passing in is None (aka NULL), we need to use
+            # IS, not =, as NULL = NULL equals NULL (False).
+            if keyvalues[key] is None:
+                return "%s IS ?" % (key,)
+            else:
+                return "%s = ?" % (key,)
+
+        if not values:
+            # If `values` is empty, then all of the values we care about are in
+            # the unique key, so there is nothing to UPDATE. We can just do a
+            # SELECT instead to see if it exists.
+            sql = "SELECT 1 FROM %s WHERE %s" % (
+                table,
+                " AND ".join(_getwhere(k) for k in keyvalues),
+            )
+            sqlargs = list(keyvalues.values())
+            txn.execute(sql, sqlargs)
+            if txn.fetchall():
+                # We have an existing record.
+                return False
+        else:
+            # First try to update.
+            sql = "UPDATE %s SET %s WHERE %s" % (
+                table,
+                ", ".join("%s = ?" % (k,) for k in values),
+                " AND ".join(_getwhere(k) for k in keyvalues),
+            )
+            sqlargs = list(values.values()) + list(keyvalues.values())
+
+            txn.execute(sql, sqlargs)
+            if txn.rowcount > 0:
+                # successfully updated at least one row.
+                return False
+
+        # We didn't find any existing rows, so insert a new one
+        allvalues = {}  # type: Dict[str, Any]
+        allvalues.update(keyvalues)
+        allvalues.update(values)
+        allvalues.update(insertion_values)
+
+        sql = "INSERT INTO %s (%s) VALUES (%s)" % (
+            table,
+            ", ".join(k for k in allvalues),
+            ", ".join("?" for _ in allvalues),
+        )
+        txn.execute(sql, list(allvalues.values()))
+        # successfully inserted
+        return True
+
+    def simple_upsert_txn_native_upsert(
+        self, txn, table, keyvalues, values, insertion_values={}
+    ):
+        """
+        Use the native UPSERT functionality in recent PostgreSQL versions.
+
+        Args:
+            table (str): The table to upsert into
+            keyvalues (dict): The unique key columns and their new values
+            values (dict): The nonunique columns and their new values
+            insertion_values (dict): additional key/values to use only when
+                inserting
+        Returns:
+            None
+        """
+        allvalues = {}  # type: Dict[str, Any]
+        allvalues.update(keyvalues)
+        allvalues.update(insertion_values)
+
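+        # If there are no non-key values to set then a conflict is a no-op
+        # (DO NOTHING); otherwise overwrite the conflicting row's value columns
+        # with the EXCLUDED (i.e. proposed) values.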
+        if not values:
+            latter = "NOTHING"
+        else:
+            allvalues.update(values)
+            latter = "UPDATE SET " + ", ".join(k + "=EXCLUDED." + k for k in values)
+
+        sql = ("INSERT INTO %s (%s) VALUES (%s) ON CONFLICT (%s) DO %s") % (
+            table,
+            ", ".join(k for k in allvalues),
+            ", ".join("?" for _ in allvalues),
+            ", ".join(k for k in keyvalues),
+            latter,
+        )
+        txn.execute(sql, list(allvalues.values()))
+
+    def simple_upsert_many_txn(
+        self,
+        txn: LoggingTransaction,
+        table: str,
+        key_names: Collection[str],
+        key_values: Collection[Iterable[Any]],
+        value_names: Collection[str],
+        value_values: Iterable[Iterable[str]],
+    ) -> None:
+        """
+        Upsert, many times.
+
+        Args:
+            table: The table to upsert into
+            key_names: The key column names.
+            key_values: A list of each row's key column values.
+            value_names: The value column names
+            value_values: A list of each row's value column values.
+                Ignored if value_names is empty.
+        """
+        if self.engine.can_native_upsert and table not in self._unsafe_to_upsert_tables:
+            return self.simple_upsert_many_txn_native_upsert(
+                txn, table, key_names, key_values, value_names, value_values
+            )
+        else:
+            return self.simple_upsert_many_txn_emulated(
+                txn, table, key_names, key_values, value_names, value_values
+            )
+
+    def simple_upsert_many_txn_emulated(
+        self,
+        txn: LoggingTransaction,
+        table: str,
+        key_names: Iterable[str],
+        key_values: Collection[Iterable[Any]],
+        value_names: Collection[str],
+        value_values: Iterable[Iterable[str]],
+    ) -> None:
+        """
+        Upsert, many times, but without native UPSERT support or batching.
+
+        Args:
+            table: The table to upsert into
+            key_names: The key column names.
+            key_values: A list of each row's key column values.
+            value_names: The value column names
+            value_values: A list of each row's value column values.
+                Ignored if value_names is empty.
+        """
+        # No value columns, therefore make a blank list so that the following
+        # zip() works correctly.
+        if not value_names:
+            value_values = [() for x in range(len(key_values))]
+
+        for keyv, valv in zip(key_values, value_values):
+            _keys = {x: y for x, y in zip(key_names, keyv)}
+            _vals = {x: y for x, y in zip(value_names, valv)}
+
+            self.simple_upsert_txn_emulated(txn, table, _keys, _vals)
+
+    def simple_upsert_many_txn_native_upsert(
+        self,
+        txn: LoggingTransaction,
+        table: str,
+        key_names: Collection[str],
+        key_values: Collection[Iterable[Any]],
+        value_names: Collection[str],
+        value_values: Iterable[Iterable[Any]],
+    ) -> None:
+        """
+        Upsert, many times, using batching where possible.
+
+        Args:
+            table: The table to upsert into
+            key_names: The key column names.
+            key_values: A list of each row's key column values.
+            value_names: The value column names
+            value_values: A list of each row's value column values.
+                Ignored if value_names is empty.
+        """
+        allnames = []  # type: List[str]
+        allnames.extend(key_names)
+        allnames.extend(value_names)
+
+        if not value_names:
+            latter = "NOTHING"
+            # No value columns, therefore make a blank list so that the
+            # following zip() works correctly.
+            value_values = [() for x in range(len(key_values))]
+        else:
+            latter = "UPDATE SET " + ", ".join(
+                k + "=EXCLUDED." + k for k in value_names
+            )
+
+        sql = "INSERT INTO %s (%s) VALUES (%s) ON CONFLICT (%s) DO %s" % (
+            table,
+            ", ".join(k for k in allnames),
+            ", ".join("?" for _ in allnames),
+            ", ".join(key_names),
+            latter,
+        )
+
+        args = []
+
+        for x, y in zip(key_values, value_values):
+            args.append(tuple(x) + tuple(y))
+
+        return txn.execute_batch(sql, args)
+
+    def simple_select_one(
+        self, table, keyvalues, retcols, allow_none=False, desc="simple_select_one"
+    ):
+        """Executes a SELECT query on the named table, which is expected to
+        return a single row, returning multiple columns from it.
+
+        Args:
+            table : string giving the table name
+            keyvalues : dict of column names and values to select the row with
+            retcols : list of strings giving the names of the columns to return
+
+            allow_none : If true, return None instead of failing if the SELECT
+              statement returns no rows
+        """
+        return self.runInteraction(
+            desc, self.simple_select_one_txn, table, keyvalues, retcols, allow_none
+        )
+
+    def simple_select_one_onecol(
+        self,
+        table,
+        keyvalues,
+        retcol,
+        allow_none=False,
+        desc="simple_select_one_onecol",
+    ):
+        """Executes a SELECT query on the named table, which is expected to
+        return a single row, returning a single column from it.
+
+        Args:
+            table : string giving the table name
+            keyvalues : dict of column names and values to select the row with
+            retcol : string giving the name of the column to return
+        """
+        return self.runInteraction(
+            desc,
+            self.simple_select_one_onecol_txn,
+            table,
+            keyvalues,
+            retcol,
+            allow_none=allow_none,
+        )
+
+    @classmethod
+    def simple_select_one_onecol_txn(
+        cls, txn, table, keyvalues, retcol, allow_none=False
+    ):
+        ret = cls.simple_select_onecol_txn(
+            txn, table=table, keyvalues=keyvalues, retcol=retcol
+        )
+
+        if ret:
+            return ret[0]
+        else:
+            if allow_none:
+                return None
+            else:
+                raise StoreError(404, "No row found")
+
+    @staticmethod
+    def simple_select_onecol_txn(txn, table, keyvalues, retcol):
+        sql = ("SELECT %(retcol)s FROM %(table)s") % {"retcol": retcol, "table": table}
+
+        if keyvalues:
+            sql += " WHERE %s" % " AND ".join("%s = ?" % k for k in iterkeys(keyvalues))
+            txn.execute(sql, list(keyvalues.values()))
+        else:
+            txn.execute(sql)
+
+        return [r[0] for r in txn]
+
+    def simple_select_onecol(
+        self, table, keyvalues, retcol, desc="simple_select_onecol"
+    ):
+        """Executes a SELECT query on the named table, which returns a list
+        comprising of the values of the named column from the selected rows.
+
+        Args:
+            table (str): table name
+            keyvalues (dict|None): column names and values to select the rows with
+            retcol (str): column whos value we wish to retrieve.
+
+        Returns:
+            Deferred: Results in a list
+        """
+        return self.runInteraction(
+            desc, self.simple_select_onecol_txn, table, keyvalues, retcol
+        )
+
+    def simple_select_list(self, table, keyvalues, retcols, desc="simple_select_list"):
+        """Executes a SELECT query on the named table, which may return zero or
+        more rows, returning the result as a list of dicts.
+
+        Args:
+            table (str): the table name
+            keyvalues (dict[str, Any] | None):
+                column names and values to select the rows with, or None to not
+                apply a WHERE clause.
+            retcols (iterable[str]): the names of the columns to return
+        Returns:
+            defer.Deferred: resolves to list[dict[str, Any]]
+        """
+        return self.runInteraction(
+            desc, self.simple_select_list_txn, table, keyvalues, retcols
+        )
+
+    @classmethod
+    def simple_select_list_txn(cls, txn, table, keyvalues, retcols):
+        """Executes a SELECT query on the named table, which may return zero or
+        more rows, returning the result as a list of dicts.
+
+        Args:
+            txn : Transaction object
+            table (str): the table name
+            keyvalues (dict[str, T] | None):
+                column names and values to select the rows with, or None to not
+                apply a WHERE clause.
+            retcols (iterable[str]): the names of the columns to return
+        """
+        if keyvalues:
+            sql = "SELECT %s FROM %s WHERE %s" % (
+                ", ".join(retcols),
+                table,
+                " AND ".join("%s = ?" % (k,) for k in keyvalues),
+            )
+            txn.execute(sql, list(keyvalues.values()))
+        else:
+            sql = "SELECT %s FROM %s" % (", ".join(retcols), table)
+            txn.execute(sql)
+
+        return cls.cursor_to_dict(txn)
+
+    @defer.inlineCallbacks
+    def simple_select_many_batch(
+        self,
+        table,
+        column,
+        iterable,
+        retcols,
+        keyvalues={},
+        desc="simple_select_many_batch",
+        batch_size=100,
+    ):
+        """Executes a SELECT query on the named table, which may return zero or
+        more rows, returning the result as a list of dicts.
+
+        Filters rows by whether the value of `column` is in `iterable`.
+
+        Args:
+            table : string giving the table name
+            column : column name to test for inclusion against `iterable`
+            iterable : list
+            keyvalues : dict of column names and values to select the rows with
+            retcols : list of strings giving the names of the columns to return
+        """
+        results = []  # type: List[Dict[str, Any]]
+
+        if not iterable:
+            return results
+
+        # iterables can not be sliced, so convert it to a list first
+        it_list = list(iterable)
+
+        chunks = [
+            it_list[i : i + batch_size] for i in range(0, len(it_list), batch_size)
+        ]
+        for chunk in chunks:
+            rows = yield self.runInteraction(
+                desc,
+                self.simple_select_many_txn,
+                table,
+                column,
+                chunk,
+                keyvalues,
+                retcols,
+            )
+
+            results.extend(rows)
+
+        return results
+
+    @classmethod
+    def simple_select_many_txn(cls, txn, table, column, iterable, keyvalues, retcols):
+        """Executes a SELECT query on the named table, which may return zero or
+        more rows, returning the result as a list of dicts.
+
+        Filters rows by whether the value of `column` is in `iterable`.
+
+        Args:
+            txn : Transaction object
+            table : string giving the table name
+            column : column name to test for inclusion against `iterable`
+            iterable : list
+            keyvalues : dict of column names and values to select the rows with
+            retcols : list of strings giving the names of the columns to return
+        """
+        if not iterable:
+            return []
+
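+        # Build a "column IN (...)" clause (or the engine-specific equivalent)
+        # for the values in `iterable`.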
+        clause, values = make_in_list_sql_clause(txn.database_engine, column, iterable)
+        clauses = [clause]
+
+        for key, value in iteritems(keyvalues):
+            clauses.append("%s = ?" % (key,))
+            values.append(value)
+
+        sql = "SELECT %s FROM %s WHERE %s" % (
+            ", ".join(retcols),
+            table,
+            " AND ".join(clauses),
+        )
+
+        txn.execute(sql, values)
+        return cls.cursor_to_dict(txn)
+
+    def simple_update(self, table, keyvalues, updatevalues, desc):
+        return self.runInteraction(
+            desc, self.simple_update_txn, table, keyvalues, updatevalues
+        )
+
+    @staticmethod
+    def simple_update_txn(txn, table, keyvalues, updatevalues):
+        if keyvalues:
+            where = "WHERE %s" % " AND ".join("%s = ?" % k for k in iterkeys(keyvalues))
+        else:
+            where = ""
+
+        update_sql = "UPDATE %s SET %s %s" % (
+            table,
+            ", ".join("%s = ?" % (k,) for k in updatevalues),
+            where,
+        )
+
+        txn.execute(update_sql, list(updatevalues.values()) + list(keyvalues.values()))
+
+        return txn.rowcount
+
+    def simple_update_one(
+        self, table, keyvalues, updatevalues, desc="simple_update_one"
+    ):
+        """Executes an UPDATE query on the named table, setting new values for
+        columns in a row matching the key values.
+
+        Args:
+            table : string giving the table name
+            keyvalues : dict of column names and values to select the row with
+            updatevalues : dict giving column names and values to update
+
+        This can be used to implement compare-and-set by putting the column
+        being updated in the 'keyvalues' dict as well: the update only
+        succeeds if the stored value matches the expected one.
+        """
+        return self.runInteraction(
+            desc, self.simple_update_one_txn, table, keyvalues, updatevalues
+        )
+
+    @classmethod
+    def simple_update_one_txn(cls, txn, table, keyvalues, updatevalues):
+        rowcount = cls.simple_update_txn(txn, table, keyvalues, updatevalues)
+
+        if rowcount == 0:
+            raise StoreError(404, "No row found (%s)" % (table,))
+        if rowcount > 1:
+            raise StoreError(500, "More than one row matched (%s)" % (table,))
+
+    @staticmethod
+    def simple_select_one_txn(txn, table, keyvalues, retcols, allow_none=False):
+        select_sql = "SELECT %s FROM %s WHERE %s" % (
+            ", ".join(retcols),
+            table,
+            " AND ".join("%s = ?" % (k,) for k in keyvalues),
+        )
+
+        txn.execute(select_sql, list(keyvalues.values()))
+        row = txn.fetchone()
+
+        if not row:
+            if allow_none:
+                return None
+            raise StoreError(404, "No row found (%s)" % (table,))
+        if txn.rowcount > 1:
+            raise StoreError(500, "More than one row matched (%s)" % (table,))
+
+        return dict(zip(retcols, row))
+
+    def simple_delete_one(self, table, keyvalues, desc="simple_delete_one"):
+        """Executes a DELETE query on the named table, expecting to delete a
+        single row.
+
+        Args:
+            table : string giving the table name
+            keyvalues : dict of column names and values to select the row with
+        """
+        return self.runInteraction(desc, self.simple_delete_one_txn, table, keyvalues)
+
+    @staticmethod
+    def simple_delete_one_txn(txn, table, keyvalues):
+        """Executes a DELETE query on the named table, expecting to delete a
+        single row.
+
+        Args:
+            table : string giving the table name
+            keyvalues : dict of column names and values to select the row with
+        """
+        sql = "DELETE FROM %s WHERE %s" % (
+            table,
+            " AND ".join("%s = ?" % (k,) for k in keyvalues),
+        )
+
+        txn.execute(sql, list(keyvalues.values()))
+        if txn.rowcount == 0:
+            raise StoreError(404, "No row found (%s)" % (table,))
+        if txn.rowcount > 1:
+            raise StoreError(500, "More than one row matched (%s)" % (table,))
+
+    def simple_delete(self, table, keyvalues, desc):
+        return self.runInteraction(desc, self.simple_delete_txn, table, keyvalues)
+
+    @staticmethod
+    def simple_delete_txn(txn, table, keyvalues):
+        sql = "DELETE FROM %s WHERE %s" % (
+            table,
+            " AND ".join("%s = ?" % (k,) for k in keyvalues),
+        )
+
+        txn.execute(sql, list(keyvalues.values()))
+        return txn.rowcount
+
+    def simple_delete_many(self, table, column, iterable, keyvalues, desc):
+        return self.runInteraction(
+            desc, self.simple_delete_many_txn, table, column, iterable, keyvalues
+        )
+
+    @staticmethod
+    def simple_delete_many_txn(txn, table, column, iterable, keyvalues):
+        """Executes a DELETE query on the named table.
+
+        Filters rows by whether the value of `column` is in `iterable`.
+
+        Args:
+            txn : Transaction object
+            table : string giving the table name
+            column : column name to test for inclusion against `iterable`
+            iterable : list
+            keyvalues : dict of column names and values to select the rows with
+
+        Returns:
+            int: Number of rows deleted
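+
+        As an illustration (hypothetical values), a call with
+        column="event_id", iterable=["$a", "$b"] and keyvalues={"room_id": rid}
+        runs SQL along the lines of
+        `DELETE FROM <table> WHERE event_id IN (?, ?) AND room_id = ?`
+        (on SQLite; on Postgres the IN list becomes `event_id = ANY(?)`).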
+        """
+        if not iterable:
+            return 0
+
+        sql = "DELETE FROM %s" % table
+
+        clause, values = make_in_list_sql_clause(txn.database_engine, column, iterable)
+        clauses = [clause]
+
+        for key, value in iteritems(keyvalues):
+            clauses.append("%s = ?" % (key,))
+            values.append(value)
+
+        if clauses:
+            sql = "%s WHERE %s" % (sql, " AND ".join(clauses))
+        txn.execute(sql, values)
+
+        return txn.rowcount
+
+    def get_cache_dict(
+        self, db_conn, table, entity_column, stream_column, max_value, limit=100000
+    ):
+        # Fetch a mapping of room_id -> max stream position for "recent" rooms.
+        # It doesn't really matter how many we get; the StreamChangeCache will
+        # do the right thing to ensure it respects its maximum cache size.
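+        #
+        # As an illustration (hypothetical columns), with table="events",
+        # entity_column="room_id" and stream_column="stream_ordering" the
+        # query built below expands to roughly:
+        #
+        #   SELECT room_id, MAX(stream_ordering) FROM events
+        #    WHERE stream_ordering > ? - 100000 GROUP BY room_id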
+        sql = (
+            "SELECT %(entity)s, MAX(%(stream)s) FROM %(table)s"
+            " WHERE %(stream)s > ? - %(limit)s"
+            " GROUP BY %(entity)s"
+        ) % {
+            "table": table,
+            "entity": entity_column,
+            "stream": stream_column,
+            "limit": limit,
+        }
+
+        sql = self.engine.convert_param_style(sql)
+
+        txn = db_conn.cursor()
+        txn.execute(sql, (int(max_value),))
+
+        cache = {row[0]: int(row[1]) for row in txn}
+
+        txn.close()
+
+        if cache:
+            min_val = min(itervalues(cache))
+        else:
+            min_val = max_value
+
+        return cache, min_val
+
+    def simple_select_list_paginate(
+        self,
+        table,
+        orderby,
+        start,
+        limit,
+        retcols,
+        filters=None,
+        keyvalues=None,
+        order_direction="ASC",
+        desc="simple_select_list_paginate",
+    ):
+        """
+        Executes a SELECT query on the named table, applying `start` and
+        `limit` to the row numbers; between zero and `limit` rows are
+        returned as a list of dicts.
+
+        Args:
+            table (str): the table name
+            filters (dict[str, T] | None):
+                column names and values to filter the rows with (matched using
+                LIKE), or None to not apply any LIKE clauses.
+            keyvalues (dict[str, T] | None):
+                column names and values to select the rows with, or None to not
+                apply a WHERE clause.
+            orderby (str): Column to order the results by.
+            start (int): Index to begin the query at.
+            limit (int): Number of results to return.
+            retcols (iterable[str]): the names of the columns to return
+            order_direction (str): Whether the results should be ordered "ASC" or "DESC".
+        Returns:
+            defer.Deferred: resolves to list[dict[str, Any]]
+        """
+        return self.runInteraction(
+            desc,
+            self.simple_select_list_paginate_txn,
+            table,
+            orderby,
+            start,
+            limit,
+            retcols,
+            filters=filters,
+            keyvalues=keyvalues,
+            order_direction=order_direction,
+        )
+
+    @classmethod
+    def simple_select_list_paginate_txn(
+        cls,
+        txn,
+        table,
+        orderby,
+        start,
+        limit,
+        retcols,
+        filters=None,
+        keyvalues=None,
+        order_direction="ASC",
+    ):
+        """
+        Executes a SELECT query on the named table, applying `start` and
+        `limit` to the row numbers; between zero and `limit` rows are
+        returned as a list of dicts.
+
+        Use `filters` to search attributes using SQL wildcards and/or `keyvalues` to
+        select attributes with exact matches. All constraints are joined together
+        using 'AND'.
+
+        Args:
+            txn : Transaction object
+            table (str): the table name
+            orderby (str): Column to order the results by.
+            start (int): Index to begin the query at.
+            limit (int): Number of results to return.
+            retcols (iterable[str]): the names of the columns to return
+            filters (dict[str, T] | None):
+                column names and values to filter the rows with (matched using
+                LIKE), or None to not apply any LIKE clauses.
+            keyvalues (dict[str, T] | None):
+                column names and values to select the rows with, or None to not
+                apply a WHERE clause.
+            order_direction (str): Whether the results should be ordered "ASC" or "DESC".
+        Returns:
+            list[dict[str, Any]]
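+
+        As a sketch (hypothetical table and columns), filters={"name": "%bob%"}
+        and keyvalues={"deactivated": 0} produce a query shaped like
+        `SELECT <retcols> FROM users WHERE name LIKE ? AND deactivated = ?
+        ORDER BY <orderby> ASC LIMIT ? OFFSET ?`.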
+        """
+        if order_direction not in ["ASC", "DESC"]:
+            raise ValueError("order_direction must be one of 'ASC' or 'DESC'.")
+
+        where_clause = "WHERE " if filters or keyvalues else ""
+        arg_list = []  # type: List[Any]
+        if filters:
+            where_clause += " AND ".join("%s LIKE ?" % (k,) for k in filters)
+            arg_list += list(filters.values())
+        where_clause += " AND " if filters and keyvalues else ""
+        if keyvalues:
+            where_clause += " AND ".join("%s = ?" % (k,) for k in keyvalues)
+            arg_list += list(keyvalues.values())
+
+        sql = "SELECT %s FROM %s %s ORDER BY %s %s LIMIT ? OFFSET ?" % (
+            ", ".join(retcols),
+            table,
+            where_clause,
+            orderby,
+            order_direction,
+        )
+        txn.execute(sql, arg_list + [limit, start])
+
+        return cls.cursor_to_dict(txn)
+
+    def simple_search_list(self, table, term, col, retcols, desc="simple_search_list"):
+        """Executes a SELECT query on the named table, which may return zero or
+        more rows, returning the result as a list of dicts.
+
+        Args:
+            table (str): the table name
+            term (str | None):
+                search term to match against the given column.
+            col (str): column the search term should be matched against
+            retcols (iterable[str]): the names of the columns to return
+        Returns:
+            defer.Deferred: resolves to list[dict[str, Any]] or None
+        """
+
+        return self.runInteraction(
+            desc, self.simple_search_list_txn, table, term, col, retcols
+        )
+
+    @classmethod
+    def simple_search_list_txn(cls, txn, table, term, col, retcols):
+        """Executes a SELECT query on the named table, which may return zero or
+        more rows, returning the result as a list of dicts.
+
+        Args:
+            txn : Transaction object
+            table (str): the table name
+            term (str | None):
+                search term to match against the given column.
+            col (str): column the search term should be matched against
+            retcols (iterable[str]): the names of the columns to return
+        Returns:
+            list[dict[str, Any]], or 0 if no search term is given
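+
+        For illustration, term="bob" against col="name" runs a query of the
+        form `SELECT <retcols> FROM <table> WHERE name LIKE ?`, with the term
+        wrapped in SQL `%` wildcards so it matches anywhere in the column.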
+        """
+        if term:
+            sql = "SELECT %s FROM %s WHERE %s LIKE ?" % (", ".join(retcols), table, col)
+            termvalues = ["%%" + term + "%%"]
+            txn.execute(sql, termvalues)
+        else:
+            return 0
+
+        return cls.cursor_to_dict(txn)
+
+
+def make_in_list_sql_clause(
+    database_engine, column: str, iterable: Iterable
+) -> Tuple[str, list]:
+    """Returns an SQL clause that checks the given column is in the iterable.
+
+    On SQLite this expands to `column IN (?, ?, ...)`, whereas on Postgres
+    it expands to `column = ANY(?)`. While both DBs support the `IN` form,
+    using the `ANY` form on Postgres means it treats queries with
+    different-length iterables as the same query, which helps the query stats.
+
+    Args:
+        database_engine
+        column: Name of the column
+        iterable: The values to check the column against.
+
+    Returns:
+        A tuple of SQL query and the args
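+
+    For example, `make_in_list_sql_clause(engine, "user_id", ["@a", "@b"])`
+    returns `("user_id IN (?,?)", ["@a", "@b"])` on SQLite and
+    `("user_id = ANY(?)", [["@a", "@b"]])` on Postgres.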
+    """
+
+    if database_engine.supports_using_any_list:
+        # This should hopefully be faster, but also makes postgres query
+        # stats easier to understand.
+        return "%s = ANY(?)" % (column,), [list(iterable)]
+    else:
+        return "%s IN (%s)" % (column, ",".join("?" for _ in iterable)), list(iterable)
+
+
+KV = TypeVar("KV")
+
+
+def make_tuple_comparison_clause(
+    database_engine: BaseDatabaseEngine, keys: List[Tuple[str, KV]]
+) -> Tuple[str, List[KV]]:
+    """Returns a tuple comparison SQL clause
+
+    Depending what the SQL engine supports, builds a SQL clause that looks like either
+    "(a, b) > (?, ?)", or "(a > ?) OR (a == ? AND b > ?)".
+
+    Args:
+        database_engine
+        keys: A list of (column, value) pairs to be compared, most significant first.
+
+    Returns:
+        A tuple of SQL query and the args
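+
+    For example, `make_tuple_comparison_clause(engine, [("a", 1), ("b", 2)])`
+    returns `("(a,b) > (?,?)", [1, 2])` when tuple comparison is supported,
+    and `("(a >= ? AND (a > ? OR b > ?))", [1, 1, 2])` otherwise.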
+    """
+    if database_engine.supports_tuple_comparison:
+        return (
+            "(%s) > (%s)" % (",".join(k[0] for k in keys), ",".join("?" for _ in keys)),
+            [k[1] for k in keys],
+        )
+
+    # we want to build a clause
+    #    (a > ?) OR
+    #    (a == ? AND b > ?) OR
+    #    (a == ? AND b == ? AND c > ?)
+    #    ...
+    #    (a == ? AND b == ? AND ... AND z > ?)
+    #
+    # or, equivalently:
+    #
+    #  (a > ? OR (a == ? AND
+    #    (b > ? OR (b == ? AND
+    #      ...
+    #        (y > ? OR (y == ? AND
+    #          z > ?
+    #        ))
+    #      ...
+    #    ))
+    #  ))
+    #
+    # which itself is equivalent to (and apparently easier for the query optimiser):
+    #
+    #  (a >= ? AND (a > ? OR
+    #    (b >= ? AND (b > ? OR
+    #      ...
+    #        (y >= ? AND (y > ? OR
+    #          z > ?
+    #        ))
+    #      ...
+    #    ))
+    #  ))
+    #
+    #
+
+    clause = ""
+    args = []  # type: List[KV]
+    for k, v in keys[:-1]:
+        clause = clause + "(%s >= ? AND (%s > ? OR " % (k, k)
+        args.extend([v, v])
+
+    (k, v) = keys[-1]
+    clause += "%s > ?" % (k,)
+    args.append(v)
+
+    clause += "))" * (len(keys) - 1)
+    return clause, args
diff --git a/synapse/storage/end_to_end_keys.py b/synapse/storage/end_to_end_keys.py
deleted file mode 100644
index 33e3a84933..0000000000
--- a/synapse/storage/end_to_end_keys.py
+++ /dev/null
@@ -1,313 +0,0 @@
-# -*- coding: utf-8 -*-
-# Copyright 2015, 2016 OpenMarket Ltd
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-#     http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-from six import iteritems
-
-from canonicaljson import encode_canonical_json
-
-from twisted.internet import defer
-
-from synapse.logging.opentracing import log_kv, set_tag, trace
-from synapse.util.caches.descriptors import cached
-
-from ._base import SQLBaseStore, db_to_json
-
-
-class EndToEndKeyWorkerStore(SQLBaseStore):
-    @trace
-    @defer.inlineCallbacks
-    def get_e2e_device_keys(
-        self, query_list, include_all_devices=False, include_deleted_devices=False
-    ):
-        """Fetch a list of device keys.
-        Args:
-            query_list(list): List of pairs of user_ids and device_ids.
-            include_all_devices (bool): whether to include entries for devices
-                that don't have device keys
-            include_deleted_devices (bool): whether to include null entries for
-                devices which no longer exist (but were in the query_list).
-                This option only takes effect if include_all_devices is true.
-        Returns:
-            Dict mapping from user-id to dict mapping from device_id to
-            dict containing "key_json", "device_display_name".
-        """
-        set_tag("query_list", query_list)
-        if not query_list:
-            return {}
-
-        results = yield self.runInteraction(
-            "get_e2e_device_keys",
-            self._get_e2e_device_keys_txn,
-            query_list,
-            include_all_devices,
-            include_deleted_devices,
-        )
-
-        for user_id, device_keys in iteritems(results):
-            for device_id, device_info in iteritems(device_keys):
-                device_info["keys"] = db_to_json(device_info.pop("key_json"))
-
-        return results
-
-    @trace
-    def _get_e2e_device_keys_txn(
-        self, txn, query_list, include_all_devices=False, include_deleted_devices=False
-    ):
-        set_tag("include_all_devices", include_all_devices)
-        set_tag("include_deleted_devices", include_deleted_devices)
-
-        query_clauses = []
-        query_params = []
-
-        if include_all_devices is False:
-            include_deleted_devices = False
-
-        if include_deleted_devices:
-            deleted_devices = set(query_list)
-
-        for (user_id, device_id) in query_list:
-            query_clause = "user_id = ?"
-            query_params.append(user_id)
-
-            if device_id is not None:
-                query_clause += " AND device_id = ?"
-                query_params.append(device_id)
-
-            query_clauses.append(query_clause)
-
-        sql = (
-            "SELECT user_id, device_id, "
-            "    d.display_name AS device_display_name, "
-            "    k.key_json"
-            " FROM devices d"
-            "    %s JOIN e2e_device_keys_json k USING (user_id, device_id)"
-            " WHERE %s"
-        ) % (
-            "LEFT" if include_all_devices else "INNER",
-            " OR ".join("(" + q + ")" for q in query_clauses),
-        )
-
-        txn.execute(sql, query_params)
-        rows = self.cursor_to_dict(txn)
-
-        result = {}
-        for row in rows:
-            if include_deleted_devices:
-                deleted_devices.remove((row["user_id"], row["device_id"]))
-            result.setdefault(row["user_id"], {})[row["device_id"]] = row
-
-        if include_deleted_devices:
-            for user_id, device_id in deleted_devices:
-                result.setdefault(user_id, {})[device_id] = None
-
-        log_kv(result)
-        return result
-
-    @defer.inlineCallbacks
-    def get_e2e_one_time_keys(self, user_id, device_id, key_ids):
-        """Retrieve a number of one-time keys for a user
-
-        Args:
-            user_id(str): id of user to get keys for
-            device_id(str): id of device to get keys for
-            key_ids(list[str]): list of key ids (excluding algorithm) to
-                retrieve
-
-        Returns:
-            deferred resolving to Dict[(str, str), str]: map from (algorithm,
-            key_id) to json string for key
-        """
-
-        rows = yield self._simple_select_many_batch(
-            table="e2e_one_time_keys_json",
-            column="key_id",
-            iterable=key_ids,
-            retcols=("algorithm", "key_id", "key_json"),
-            keyvalues={"user_id": user_id, "device_id": device_id},
-            desc="add_e2e_one_time_keys_check",
-        )
-        result = {(row["algorithm"], row["key_id"]): row["key_json"] for row in rows}
-        log_kv({"message": "Fetched one time keys for user", "one_time_keys": result})
-        return result
-
-    @defer.inlineCallbacks
-    def add_e2e_one_time_keys(self, user_id, device_id, time_now, new_keys):
-        """Insert some new one time keys for a device. Errors if any of the
-        keys already exist.
-
-        Args:
-            user_id(str): id of user to get keys for
-            device_id(str): id of device to get keys for
-            time_now(long): insertion time to record (ms since epoch)
-            new_keys(iterable[(str, str, str)]: keys to add - each a tuple of
-                (algorithm, key_id, key json)
-        """
-
-        def _add_e2e_one_time_keys(txn):
-            set_tag("user_id", user_id)
-            set_tag("device_id", device_id)
-            set_tag("new_keys", new_keys)
-            # We are protected from race between lookup and insertion due to
-            # a unique constraint. If there is a race of two calls to
-            # `add_e2e_one_time_keys` then they'll conflict and we will only
-            # insert one set.
-            self._simple_insert_many_txn(
-                txn,
-                table="e2e_one_time_keys_json",
-                values=[
-                    {
-                        "user_id": user_id,
-                        "device_id": device_id,
-                        "algorithm": algorithm,
-                        "key_id": key_id,
-                        "ts_added_ms": time_now,
-                        "key_json": json_bytes,
-                    }
-                    for algorithm, key_id, json_bytes in new_keys
-                ],
-            )
-            self._invalidate_cache_and_stream(
-                txn, self.count_e2e_one_time_keys, (user_id, device_id)
-            )
-
-        yield self.runInteraction(
-            "add_e2e_one_time_keys_insert", _add_e2e_one_time_keys
-        )
-
-    @cached(max_entries=10000)
-    def count_e2e_one_time_keys(self, user_id, device_id):
-        """ Count the number of one time keys the server has for a device
-        Returns:
-            Dict mapping from algorithm to number of keys for that algorithm.
-        """
-
-        def _count_e2e_one_time_keys(txn):
-            sql = (
-                "SELECT algorithm, COUNT(key_id) FROM e2e_one_time_keys_json"
-                " WHERE user_id = ? AND device_id = ?"
-                " GROUP BY algorithm"
-            )
-            txn.execute(sql, (user_id, device_id))
-            result = {}
-            for algorithm, key_count in txn:
-                result[algorithm] = key_count
-            return result
-
-        return self.runInteraction("count_e2e_one_time_keys", _count_e2e_one_time_keys)
-
-
-class EndToEndKeyStore(EndToEndKeyWorkerStore, SQLBaseStore):
-    def set_e2e_device_keys(self, user_id, device_id, time_now, device_keys):
-        """Stores device keys for a device. Returns whether there was a change
-        or the keys were already in the database.
-        """
-
-        def _set_e2e_device_keys_txn(txn):
-            set_tag("user_id", user_id)
-            set_tag("device_id", device_id)
-            set_tag("time_now", time_now)
-            set_tag("device_keys", device_keys)
-
-            old_key_json = self._simple_select_one_onecol_txn(
-                txn,
-                table="e2e_device_keys_json",
-                keyvalues={"user_id": user_id, "device_id": device_id},
-                retcol="key_json",
-                allow_none=True,
-            )
-
-            # In py3 we need old_key_json to match new_key_json type. The DB
-            # returns unicode while encode_canonical_json returns bytes.
-            new_key_json = encode_canonical_json(device_keys).decode("utf-8")
-
-            if old_key_json == new_key_json:
-                log_kv({"Message": "Device key already stored."})
-                return False
-
-            self._simple_upsert_txn(
-                txn,
-                table="e2e_device_keys_json",
-                keyvalues={"user_id": user_id, "device_id": device_id},
-                values={"ts_added_ms": time_now, "key_json": new_key_json},
-            )
-            log_kv({"message": "Device keys stored."})
-            return True
-
-        return self.runInteraction("set_e2e_device_keys", _set_e2e_device_keys_txn)
-
-    def claim_e2e_one_time_keys(self, query_list):
-        """Take a list of one time keys out of the database"""
-
-        @trace
-        def _claim_e2e_one_time_keys(txn):
-            sql = (
-                "SELECT key_id, key_json FROM e2e_one_time_keys_json"
-                " WHERE user_id = ? AND device_id = ? AND algorithm = ?"
-                " LIMIT 1"
-            )
-            result = {}
-            delete = []
-            for user_id, device_id, algorithm in query_list:
-                user_result = result.setdefault(user_id, {})
-                device_result = user_result.setdefault(device_id, {})
-                txn.execute(sql, (user_id, device_id, algorithm))
-                for key_id, key_json in txn:
-                    device_result[algorithm + ":" + key_id] = key_json
-                    delete.append((user_id, device_id, algorithm, key_id))
-            sql = (
-                "DELETE FROM e2e_one_time_keys_json"
-                " WHERE user_id = ? AND device_id = ? AND algorithm = ?"
-                " AND key_id = ?"
-            )
-            for user_id, device_id, algorithm, key_id in delete:
-                log_kv(
-                    {
-                        "message": "Executing claim e2e_one_time_keys transaction on database."
-                    }
-                )
-                txn.execute(sql, (user_id, device_id, algorithm, key_id))
-                log_kv({"message": "finished executing and invalidating cache"})
-                self._invalidate_cache_and_stream(
-                    txn, self.count_e2e_one_time_keys, (user_id, device_id)
-                )
-            return result
-
-        return self.runInteraction("claim_e2e_one_time_keys", _claim_e2e_one_time_keys)
-
-    def delete_e2e_keys_by_device(self, user_id, device_id):
-        def delete_e2e_keys_by_device_txn(txn):
-            log_kv(
-                {
-                    "message": "Deleting keys for device",
-                    "device_id": device_id,
-                    "user_id": user_id,
-                }
-            )
-            self._simple_delete_txn(
-                txn,
-                table="e2e_device_keys_json",
-                keyvalues={"user_id": user_id, "device_id": device_id},
-            )
-            self._simple_delete_txn(
-                txn,
-                table="e2e_one_time_keys_json",
-                keyvalues={"user_id": user_id, "device_id": device_id},
-            )
-            self._invalidate_cache_and_stream(
-                txn, self.count_e2e_one_time_keys, (user_id, device_id)
-            )
-
-        return self.runInteraction(
-            "delete_e2e_keys_by_device", delete_e2e_keys_by_device_txn
-        )
diff --git a/synapse/storage/engines/__init__.py b/synapse/storage/engines/__init__.py
index 9d2d519922..035f9ea6e9 100644
--- a/synapse/storage/engines/__init__.py
+++ b/synapse/storage/engines/__init__.py
@@ -12,29 +12,31 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
-
-import importlib
 import platform
 
-from ._base import IncorrectDatabaseSetup
+from ._base import BaseDatabaseEngine, IncorrectDatabaseSetup
 from .postgres import PostgresEngine
 from .sqlite import Sqlite3Engine
 
-SUPPORTED_MODULE = {"sqlite3": Sqlite3Engine, "psycopg2": PostgresEngine}
-
 
-def create_engine(database_config):
+def create_engine(database_config) -> BaseDatabaseEngine:
     name = database_config["name"]
-    engine_class = SUPPORTED_MODULE.get(name, None)
 
-    if engine_class:
+    if name == "sqlite3":
+        import sqlite3
+
+        return Sqlite3Engine(sqlite3, database_config)
+
+    if name == "psycopg2":
         # pypy requires psycopg2cffi rather than psycopg2
-        if name == "psycopg2" and platform.python_implementation() == "PyPy":
-            name = "psycopg2cffi"
-        module = importlib.import_module(name)
-        return engine_class(module, database_config)
+        if platform.python_implementation() == "PyPy":
+            import psycopg2cffi as psycopg2  # type: ignore
+        else:
+            import psycopg2  # type: ignore
+
+        return PostgresEngine(psycopg2, database_config)
 
     raise RuntimeError("Unsupported database engine '%s'" % (name,))
 
 
-__all__ = ["create_engine", "IncorrectDatabaseSetup"]
+__all__ = ["create_engine", "BaseDatabaseEngine", "IncorrectDatabaseSetup"]
diff --git a/synapse/storage/engines/_base.py b/synapse/storage/engines/_base.py
index ec5a4d198b..ab0bbe4bd3 100644
--- a/synapse/storage/engines/_base.py
+++ b/synapse/storage/engines/_base.py
@@ -12,7 +12,94 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
+import abc
+from typing import Generic, TypeVar
+
+from synapse.storage.types import Connection
 
 
 class IncorrectDatabaseSetup(RuntimeError):
     pass
+
+
+ConnectionType = TypeVar("ConnectionType", bound=Connection)
+
+
+class BaseDatabaseEngine(Generic[ConnectionType], metaclass=abc.ABCMeta):
+    def __init__(self, module, database_config: dict):
+        self.module = module
+
+    @property
+    @abc.abstractmethod
+    def single_threaded(self) -> bool:
+        ...
+
+    @property
+    @abc.abstractmethod
+    def can_native_upsert(self) -> bool:
+        """
+        Do we support native UPSERTs?
+        """
+        ...
+
+    @property
+    @abc.abstractmethod
+    def supports_tuple_comparison(self) -> bool:
+        """
+        Do we support comparing tuples, i.e. `(a, b) > (c, d)`?
+        """
+        ...
+
+    @property
+    @abc.abstractmethod
+    def supports_using_any_list(self) -> bool:
+        """
+        Do we support using `a = ANY(?)` and passing a list
+        """
+        ...
+
+    @abc.abstractmethod
+    def check_database(
+        self, db_conn: ConnectionType, allow_outdated_version: bool = False
+    ) -> None:
+        ...
+
+    @abc.abstractmethod
+    def check_new_database(self, txn) -> None:
+        """Gets called when setting up a brand new database. This allows us to
+        apply stricter checks on new databases versus existing database.
+        """
+        ...
+
+    @abc.abstractmethod
+    def convert_param_style(self, sql: str) -> str:
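+        """Rewrites the `?` parameter placeholders in `sql` into whatever
+        style the underlying driver expects (for example `%s` for psycopg2,
+        while SQLite keeps `?` unchanged).
+        """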
+        ...
+
+    @abc.abstractmethod
+    def on_new_connection(self, db_conn: ConnectionType) -> None:
+        ...
+
+    @abc.abstractmethod
+    def is_deadlock(self, error: Exception) -> bool:
+        ...
+
+    @abc.abstractmethod
+    def is_connection_closed(self, conn: ConnectionType) -> bool:
+        ...
+
+    @abc.abstractmethod
+    def lock_table(self, txn, table: str) -> None:
+        ...
+
+    @abc.abstractmethod
+    def get_next_state_group_id(self, txn) -> int:
+        """Returns an int that can be used as a new state_group ID
+        """
+        ...
+
+    @property
+    @abc.abstractmethod
+    def server_version(self) -> str:
+        """Gets a string giving the server version. For example: '3.22.0'
+        """
+        ...
diff --git a/synapse/storage/engines/postgres.py b/synapse/storage/engines/postgres.py
index 289b6bc281..6c7d08a6f2 100644
--- a/synapse/storage/engines/postgres.py
+++ b/synapse/storage/engines/postgres.py
@@ -13,32 +13,32 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
-from ._base import IncorrectDatabaseSetup
+import logging
 
+from ._base import BaseDatabaseEngine, IncorrectDatabaseSetup
 
-class PostgresEngine(object):
-    single_threaded = False
+logger = logging.getLogger(__name__)
 
+
+class PostgresEngine(BaseDatabaseEngine):
     def __init__(self, database_module, database_config):
-        self.module = database_module
+        super().__init__(database_module, database_config)
         self.module.extensions.register_type(self.module.extensions.UNICODE)
-        self.synchronous_commit = database_config.get("synchronous_commit", True)
-        self._version = None  # unknown as yet
 
-    def check_database(self, txn):
-        txn.execute("SHOW SERVER_ENCODING")
-        rows = txn.fetchall()
-        if rows and rows[0][0] != "UTF8":
-            raise IncorrectDatabaseSetup(
-                "Database has incorrect encoding: '%s' instead of 'UTF8'\n"
-                "See docs/postgres.rst for more information." % (rows[0][0],)
-            )
+        # Disables passing `bytes` to txn.execute, cf. #6186. If you do
+        # actually want to pass bytes, wrap them in `bytearray`.
+        def _disable_bytes_adapter(_):
+            raise Exception("Passing bytes to DB is disabled.")
 
-    def convert_param_style(self, sql):
-        return sql.replace("?", "%s")
+        self.module.extensions.register_adapter(bytes, _disable_bytes_adapter)
+        self.synchronous_commit = database_config.get("synchronous_commit", True)
+        self._version = None  # unknown as yet
 
-    def on_new_connection(self, db_conn):
+    @property
+    def single_threaded(self) -> bool:
+        return False
 
+    def check_database(self, db_conn, allow_outdated_version: bool = False):
         # Get the version of PostgreSQL that we're using. As per the psycopg2
         # docs: The number is formed by converting the major, minor, and
         # revision numbers into two-decimal-digit numbers and appending them
@@ -46,9 +46,64 @@ class PostgresEngine(object):
         self._version = db_conn.server_version
 
         # Are we on a supported PostgreSQL version?
-        if self._version < 90500:
+        if not allow_outdated_version and self._version < 90500:
             raise RuntimeError("Synapse requires PostgreSQL 9.5+ or above.")
 
+        with db_conn.cursor() as txn:
+            txn.execute("SHOW SERVER_ENCODING")
+            rows = txn.fetchall()
+            if rows and rows[0][0] != "UTF8":
+                raise IncorrectDatabaseSetup(
+                    "Database has incorrect encoding: '%s' instead of 'UTF8'\n"
+                    "See docs/postgres.md for more information." % (rows[0][0],)
+                )
+
+            txn.execute(
+                "SELECT datcollate, datctype FROM pg_database WHERE datname = current_database()"
+            )
+            collation, ctype = txn.fetchone()
+            if collation != "C":
+                logger.warning(
+                    "Database has incorrect collation of %r. Should be 'C'\n"
+                    "See docs/postgres.md for more information.",
+                    collation,
+                )
+
+            if ctype != "C":
+                logger.warning(
+                    "Database has incorrect ctype of %r. Should be 'C'\n"
+                    "See docs/postgres.md for more information.",
+                    ctype,
+                )
+
+    def check_new_database(self, txn):
+        """Gets called when setting up a brand new database. This allows us to
+        apply stricter checks on new databases versus existing databases.
+        """
+
+        txn.execute(
+            "SELECT datcollate, datctype FROM pg_database WHERE datname = current_database()"
+        )
+        collation, ctype = txn.fetchone()
+
+        errors = []
+
+        if collation != "C":
+            errors.append("    - 'COLLATE' is set to %r. Should be 'C'" % (collation,))
+
+        if ctype != "C":
+            errors.append("    - 'CTYPE' is set to %r. Should be 'C'" % (ctype,))
+
+        if errors:
+            raise IncorrectDatabaseSetup(
+                "Database is incorrectly configured:\n\n%s\n\n"
+                "See docs/postgres.md for more information." % ("\n".join(errors))
+            )
+
+    def convert_param_style(self, sql):
+        return sql.replace("?", "%s")
+
+    def on_new_connection(self, db_conn):
         db_conn.set_isolation_level(
             self.module.extensions.ISOLATION_LEVEL_REPEATABLE_READ
         )
@@ -72,6 +127,19 @@ class PostgresEngine(object):
         """
         return True
 
+    @property
+    def supports_tuple_comparison(self):
+        """
+        Do we support comparing tuples, i.e. `(a, b) > (c, d)`?
+        """
+        return True
+
+    @property
+    def supports_using_any_list(self):
+        """Do we support using `a = ANY(?)` and passing a list
+        """
+        return True
+
     def is_deadlock(self, error):
         if isinstance(error, self.module.DatabaseError):
             # https://www.postgresql.org/docs/current/static/errcodes-appendix.html
@@ -99,8 +167,8 @@ class PostgresEngine(object):
         Returns:
             string
         """
-        # note that this is a bit of a hack because it relies on on_new_connection
-        # having been called at least once. Still, that should be a safe bet here.
+        # note that this is a bit of a hack because it relies on check_database
+        # having been called. Still, that should be a safe bet here.
         numver = self._version
         assert numver is not None
 
diff --git a/synapse/storage/engines/sqlite.py b/synapse/storage/engines/sqlite.py
index e9b9caa49a..215a949442 100644
--- a/synapse/storage/engines/sqlite.py
+++ b/synapse/storage/engines/sqlite.py
@@ -12,18 +12,22 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
-
 import struct
 import threading
+import typing
 
-from synapse.storage.prepare_database import prepare_database
+from synapse.storage.engines import BaseDatabaseEngine
 
+if typing.TYPE_CHECKING:
+    import sqlite3  # noqa: F401
 
-class Sqlite3Engine(object):
-    single_threaded = True
 
+class Sqlite3Engine(BaseDatabaseEngine["sqlite3.Connection"]):
     def __init__(self, database_module, database_config):
-        self.module = database_module
+        super().__init__(database_module, database_config)
+
+        database = database_config.get("args", {}).get("database")
+        self._is_in_memory = database in (None, ":memory:",)
 
         # The current max state_group, or None if we haven't looked
         # in the DB yet.
@@ -31,6 +35,10 @@ class Sqlite3Engine(object):
         self._current_state_group_id_lock = threading.Lock()
 
     @property
+    def single_threaded(self) -> bool:
+        return True
+
+    @property
     def can_native_upsert(self):
         """
         Do we support native UPSERTs? This requires SQLite3 3.24+, plus some
@@ -38,15 +46,46 @@ class Sqlite3Engine(object):
         """
         return self.module.sqlite_version_info >= (3, 24, 0)
 
-    def check_database(self, txn):
-        pass
+    @property
+    def supports_tuple_comparison(self):
+        """
+        Do we support comparing tuples, i.e. `(a, b) > (c, d)`? This requires
+        SQLite 3.15+.
+        """
+        return self.module.sqlite_version_info >= (3, 15, 0)
+
+    @property
+    def supports_using_any_list(self):
+        """Do we support using `a = ANY(?)` and passing a list
+        """
+        return False
+
+    def check_database(self, db_conn, allow_outdated_version: bool = False):
+        if not allow_outdated_version:
+            version = self.module.sqlite_version_info
+            if version < (3, 11, 0):
+                raise RuntimeError("Synapse requires sqlite 3.11 or above.")
+
+    def check_new_database(self, txn):
+        """Gets called when setting up a brand new database. This allows us to
+        apply stricter checks on new databases versus existing databases.
+        """
 
     def convert_param_style(self, sql):
         return sql
 
     def on_new_connection(self, db_conn):
-        prepare_database(db_conn, self, config=None)
+        # We need to import here to avoid an import loop.
+        from synapse.storage.prepare_database import prepare_database
+
+        if self._is_in_memory:
+            # In memory databases need to be rebuilt each time. Ideally we'd
+            # reuse the same connection as we do when starting up, but that
+            # would involve using adbapi before we have started the reactor.
+            prepare_database(db_conn, self, config=None)
+
         db_conn.create_function("rank", 1, _rank)
+        db_conn.execute("PRAGMA foreign_keys = ON;")
 
     def is_deadlock(self, error):
         return False
diff --git a/synapse/storage/events.py b/synapse/storage/events.py
deleted file mode 100644
index ddf7ab6479..0000000000
--- a/synapse/storage/events.py
+++ /dev/null
@@ -1,2470 +0,0 @@
-# -*- coding: utf-8 -*-
-# Copyright 2014-2016 OpenMarket Ltd
-# Copyright 2018-2019 New Vector Ltd
-# Copyright 2019 The Matrix.org Foundation C.I.C.
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-#     http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-
-import itertools
-import logging
-from collections import Counter as c_counter, OrderedDict, deque, namedtuple
-from functools import wraps
-
-from six import iteritems, text_type
-from six.moves import range
-
-from canonicaljson import encode_canonical_json, json
-from prometheus_client import Counter, Histogram
-
-from twisted.internet import defer
-
-import synapse.metrics
-from synapse.api.constants import EventTypes
-from synapse.api.errors import SynapseError
-from synapse.events import EventBase  # noqa: F401
-from synapse.events.snapshot import EventContext  # noqa: F401
-from synapse.events.utils import prune_event_dict
-from synapse.logging.context import PreserveLoggingContext, make_deferred_yieldable
-from synapse.logging.utils import log_function
-from synapse.metrics import BucketCollector
-from synapse.metrics.background_process_metrics import run_as_background_process
-from synapse.state import StateResolutionStore
-from synapse.storage.background_updates import BackgroundUpdateStore
-from synapse.storage.event_federation import EventFederationStore
-from synapse.storage.events_worker import EventsWorkerStore
-from synapse.storage.state import StateGroupWorkerStore
-from synapse.types import RoomStreamToken, get_domain_from_id
-from synapse.util import batch_iter
-from synapse.util.async_helpers import ObservableDeferred
-from synapse.util.caches.descriptors import cached, cachedInlineCallbacks
-from synapse.util.frozenutils import frozendict_json_encoder
-from synapse.util.metrics import Measure
-
-logger = logging.getLogger(__name__)
-
-persist_event_counter = Counter("synapse_storage_events_persisted_events", "")
-event_counter = Counter(
-    "synapse_storage_events_persisted_events_sep",
-    "",
-    ["type", "origin_type", "origin_entity"],
-)
-
-# The number of times we are recalculating the current state
-state_delta_counter = Counter("synapse_storage_events_state_delta", "")
-
-# The number of times we are recalculating state when there is only a
-# single forward extremity
-state_delta_single_event_counter = Counter(
-    "synapse_storage_events_state_delta_single_event", ""
-)
-
-# The number of times we are reculating state when we could have resonably
-# calculated the delta when we calculated the state for an event we were
-# persisting.
-state_delta_reuse_delta_counter = Counter(
-    "synapse_storage_events_state_delta_reuse_delta", ""
-)
-
-# The number of forward extremities for each new event.
-forward_extremities_counter = Histogram(
-    "synapse_storage_events_forward_extremities_persisted",
-    "Number of forward extremities for each new event",
-    buckets=(1, 2, 3, 5, 7, 10, 15, 20, 50, 100, 200, 500, "+Inf"),
-)
-
-# The number of stale forward extremities for each new event. Stale extremities
-# are those that were in the previous set of extremities as well as the new.
-stale_forward_extremities_counter = Histogram(
-    "synapse_storage_events_stale_forward_extremities_persisted",
-    "Number of unchanged forward extremities for each new event",
-    buckets=(0, 1, 2, 3, 5, 7, 10, 15, 20, 50, 100, 200, 500, "+Inf"),
-)
-
-
-def encode_json(json_object):
-    """
-    Encode a Python object as JSON and return it in a Unicode string.
-    """
-    out = frozendict_json_encoder.encode(json_object)
-    if isinstance(out, bytes):
-        out = out.decode("utf8")
-    return out
-
-
-class _EventPeristenceQueue(object):
-    """Queues up events so that they can be persisted in bulk with only one
-    concurrent transaction per room.
-    """
-
-    _EventPersistQueueItem = namedtuple(
-        "_EventPersistQueueItem", ("events_and_contexts", "backfilled", "deferred")
-    )
-
-    def __init__(self):
-        self._event_persist_queues = {}
-        self._currently_persisting_rooms = set()
-
-    def add_to_queue(self, room_id, events_and_contexts, backfilled):
-        """Add events to the queue, with the given persist_event options.
-
-        NB: due to the normal usage pattern of this method, it does *not*
-        follow the synapse logcontext rules, and leaves the logcontext in
-        place whether or not the returned deferred is ready.
-
-        Args:
-            room_id (str):
-            events_and_contexts (list[(EventBase, EventContext)]):
-            backfilled (bool):
-
-        Returns:
-            defer.Deferred: a deferred which will resolve once the events are
-                persisted. Runs its callbacks *without* a logcontext.
-        """
-        queue = self._event_persist_queues.setdefault(room_id, deque())
-        if queue:
-            # if the last item in the queue has the same `backfilled` setting,
-            # we can just add these new events to that item.
-            end_item = queue[-1]
-            if end_item.backfilled == backfilled:
-                end_item.events_and_contexts.extend(events_and_contexts)
-                return end_item.deferred.observe()
-
-        deferred = ObservableDeferred(defer.Deferred(), consumeErrors=True)
-
-        queue.append(
-            self._EventPersistQueueItem(
-                events_and_contexts=events_and_contexts,
-                backfilled=backfilled,
-                deferred=deferred,
-            )
-        )
-
-        return deferred.observe()
-
-    def handle_queue(self, room_id, per_item_callback):
-        """Attempts to handle the queue for a room if not already being handled.
-
-        The given callback will be invoked with for each item in the queue,
-        of type _EventPersistQueueItem. The per_item_callback will continuously
-        be called with new items, unless the queue becomnes empty. The return
-        value of the function will be given to the deferreds waiting on the item,
-        exceptions will be passed to the deferreds as well.
-
-        This function should therefore be called whenever anything is added
-        to the queue.
-
-        If another callback is currently handling the queue then it will not be
-        invoked.
-        """
-
-        if room_id in self._currently_persisting_rooms:
-            return
-
-        self._currently_persisting_rooms.add(room_id)
-
-        @defer.inlineCallbacks
-        def handle_queue_loop():
-            try:
-                queue = self._get_drainining_queue(room_id)
-                for item in queue:
-                    try:
-                        ret = yield per_item_callback(item)
-                    except Exception:
-                        with PreserveLoggingContext():
-                            item.deferred.errback()
-                    else:
-                        with PreserveLoggingContext():
-                            item.deferred.callback(ret)
-            finally:
-                queue = self._event_persist_queues.pop(room_id, None)
-                if queue:
-                    self._event_persist_queues[room_id] = queue
-                self._currently_persisting_rooms.discard(room_id)
-
-        # set handle_queue_loop off in the background
-        run_as_background_process("persist_events", handle_queue_loop)
-
-    def _get_drainining_queue(self, room_id):
-        queue = self._event_persist_queues.setdefault(room_id, deque())
-
-        try:
-            while True:
-                yield queue.popleft()
-        except IndexError:
-            # Queue has been drained.
-            pass
-
-
-_EventCacheEntry = namedtuple("_EventCacheEntry", ("event", "redacted_event"))
-
-
-def _retry_on_integrity_error(func):
-    """Wraps a database function so that it gets retried on IntegrityError,
-    with `delete_existing=True` passed in.
-
-    Args:
-        func: function that returns a Deferred and accepts a `delete_existing` arg
-    """
-
-    @wraps(func)
-    @defer.inlineCallbacks
-    def f(self, *args, **kwargs):
-        try:
-            res = yield func(self, *args, **kwargs)
-        except self.database_engine.module.IntegrityError:
-            logger.exception("IntegrityError, retrying.")
-            res = yield func(self, *args, delete_existing=True, **kwargs)
-        return res
-
-    return f
-
-
-# inherits from EventFederationStore so that we can call _update_backward_extremities
-# and _handle_mult_prev_events (though arguably those could both be moved in here)
-class EventsStore(
-    StateGroupWorkerStore,
-    EventFederationStore,
-    EventsWorkerStore,
-    BackgroundUpdateStore,
-):
-    def __init__(self, db_conn, hs):
-        super(EventsStore, self).__init__(db_conn, hs)
-
-        self._event_persist_queue = _EventPeristenceQueue()
-        self._state_resolution_handler = hs.get_state_resolution_handler()
-
-        # Collect metrics on the number of forward extremities that exist.
-        # Counter of number of extremities to count
-        self._current_forward_extremities_amount = c_counter()
-
-        BucketCollector(
-            "synapse_forward_extremities",
-            lambda: self._current_forward_extremities_amount,
-            buckets=[1, 2, 3, 5, 7, 10, 15, 20, 50, 100, 200, 500, "+Inf"],
-        )
-
-        # Read the extrems every 60 minutes
-        def read_forward_extremities():
-            # run as a background process to make sure that the database transactions
-            # have a logcontext to report to
-            return run_as_background_process(
-                "read_forward_extremities", self._read_forward_extremities
-            )
-
-        hs.get_clock().looping_call(read_forward_extremities, 60 * 60 * 1000)
-
-        def _censor_redactions():
-            return run_as_background_process(
-                "_censor_redactions", self._censor_redactions
-            )
-
-        if self.hs.config.redaction_retention_period is not None:
-            hs.get_clock().looping_call(_censor_redactions, 5 * 60 * 1000)
-
-    @defer.inlineCallbacks
-    def _read_forward_extremities(self):
-        def fetch(txn):
-            txn.execute(
-                """
-                select count(*) c from event_forward_extremities
-                group by room_id
-                """
-            )
-            return txn.fetchall()
-
-        res = yield self.runInteraction("read_forward_extremities", fetch)
-        self._current_forward_extremities_amount = c_counter(list(x[0] for x in res))
-
-    @defer.inlineCallbacks
-    def persist_events(self, events_and_contexts, backfilled=False):
-        """
-        Write events to the database
-        Args:
-            events_and_contexts: list of tuples of (event, context)
-            backfilled (bool): Whether the results are retrieved from federation
-                via backfill or not. Used to determine if they're "new" events
-                which might update the current state etc.
-
-        Returns:
-            Deferred[int]: the stream ordering of the latest persisted event
-        """
-        partitioned = {}
-        for event, ctx in events_and_contexts:
-            partitioned.setdefault(event.room_id, []).append((event, ctx))
-
-        deferreds = []
-        for room_id, evs_ctxs in iteritems(partitioned):
-            d = self._event_persist_queue.add_to_queue(
-                room_id, evs_ctxs, backfilled=backfilled
-            )
-            deferreds.append(d)
-
-        for room_id in partitioned:
-            self._maybe_start_persisting(room_id)
-
-        yield make_deferred_yieldable(
-            defer.gatherResults(deferreds, consumeErrors=True)
-        )
-
-        max_persisted_id = yield self._stream_id_gen.get_current_token()
-
-        return max_persisted_id
-
-    @defer.inlineCallbacks
-    @log_function
-    def persist_event(self, event, context, backfilled=False):
-        """
-
-        Args:
-            event (EventBase):
-            context (EventContext):
-            backfilled (bool):
-
-        Returns:
-            Deferred: resolves to (int, int): the stream ordering of ``event``,
-            and the stream ordering of the latest persisted event
-        """
-        deferred = self._event_persist_queue.add_to_queue(
-            event.room_id, [(event, context)], backfilled=backfilled
-        )
-
-        self._maybe_start_persisting(event.room_id)
-
-        yield make_deferred_yieldable(deferred)
-
-        max_persisted_id = yield self._stream_id_gen.get_current_token()
-        return (event.internal_metadata.stream_ordering, max_persisted_id)
-
-    def _maybe_start_persisting(self, room_id):
-        @defer.inlineCallbacks
-        def persisting_queue(item):
-            with Measure(self._clock, "persist_events"):
-                yield self._persist_events(
-                    item.events_and_contexts, backfilled=item.backfilled
-                )
-
-        self._event_persist_queue.handle_queue(room_id, persisting_queue)
-
-    @_retry_on_integrity_error
-    @defer.inlineCallbacks
-    def _persist_events(
-        self, events_and_contexts, backfilled=False, delete_existing=False
-    ):
-        """Persist events to db
-
-        Args:
-            events_and_contexts (list[(EventBase, EventContext)]):
-            backfilled (bool):
-            delete_existing (bool):
-
-        Returns:
-            Deferred: resolves when the events have been persisted
-        """
-        if not events_and_contexts:
-            return
-
-        chunks = [
-            events_and_contexts[x : x + 100]
-            for x in range(0, len(events_and_contexts), 100)
-        ]
-
-        for chunk in chunks:
-            # We can't easily parallelize these since different chunks
-            # might contain the same event. :(
-
-            # NB: Assumes that we are only persisting events for one room
-            # at a time.
-
-            # map room_id->list[event_ids] giving the new forward
-            # extremities in each room
-            new_forward_extremeties = {}
-
-            # map room_id->(type,state_key)->event_id tracking the full
-            # state in each room after adding these events.
-            # This is simply used to prefill the get_current_state_ids
-            # cache
-            current_state_for_room = {}
-
-            # map room_id->(to_delete, to_insert) where to_delete is a list
-            # of type/state keys to remove from current state, and to_insert
-            # is a map (type,key)->event_id giving the state delta in each
-            # room
-            state_delta_for_room = {}
-
-            if not backfilled:
-                with Measure(self._clock, "_calculate_state_and_extrem"):
-                    # Work out the new "current state" for each room.
-                    # We do this by working out what the new extremities are and then
-                    # calculating the state from that.
-                    events_by_room = {}
-                    for event, context in chunk:
-                        events_by_room.setdefault(event.room_id, []).append(
-                            (event, context)
-                        )
-
-                    for room_id, ev_ctx_rm in iteritems(events_by_room):
-                        latest_event_ids = yield self.get_latest_event_ids_in_room(
-                            room_id
-                        )
-                        new_latest_event_ids = yield self._calculate_new_extremities(
-                            room_id, ev_ctx_rm, latest_event_ids
-                        )
-
-                        latest_event_ids = set(latest_event_ids)
-                        if new_latest_event_ids == latest_event_ids:
-                            # No change in extremities, so no change in state
-                            continue
-
-                        # there should always be at least one forward extremity.
-                        # (except during the initial persistence of the send_join
-                        # results, in which case there will be no existing
-                        # extremities, so we'll `continue` above and skip this bit.)
-                        assert new_latest_event_ids, "No forward extremities left!"
-
-                        new_forward_extremeties[room_id] = new_latest_event_ids
-
-                        len_1 = (
-                            len(latest_event_ids) == 1
-                            and len(new_latest_event_ids) == 1
-                        )
-                        if len_1:
-                            all_single_prev_not_state = all(
-                                len(event.prev_event_ids()) == 1
-                                and not event.is_state()
-                                for event, ctx in ev_ctx_rm
-                            )
-                            # Don't bother calculating state if they're just
-                            # a long chain of single ancestor non-state events.
-                            if all_single_prev_not_state:
-                                continue
-
-                        state_delta_counter.inc()
-                        if len(new_latest_event_ids) == 1:
-                            state_delta_single_event_counter.inc()
-
-                            # This is a fairly handwavey check to see if we could
-                            # have guessed what the delta would have been when
-                            # processing one of these events.
-                            # What we're interested in is if the latest extremities
-                            # were the same when we created the event as they are
-                            # now. When this server creates a new event (as opposed
-                            # to receiving it over federation) it will use the
-                            # forward extremities as the prev_events, so we can
-                            # guess this by looking at the prev_events and checking
-                            # if they match the current forward extremities.
-                            for ev, _ in ev_ctx_rm:
-                                prev_event_ids = set(ev.prev_event_ids())
-                                if latest_event_ids == prev_event_ids:
-                                    state_delta_reuse_delta_counter.inc()
-                                    break
-
-                        logger.info("Calculating state delta for room %s", room_id)
-                        with Measure(
-                            self._clock, "persist_events.get_new_state_after_events"
-                        ):
-                            res = yield self._get_new_state_after_events(
-                                room_id,
-                                ev_ctx_rm,
-                                latest_event_ids,
-                                new_latest_event_ids,
-                            )
-                            current_state, delta_ids = res
-
-                        # If either are not None then there has been a change,
-                        # and we need to work out the delta (or use that
-                        # given)
-                        if delta_ids is not None:
-                            # If there is a delta we know that we've
-                            # only added or replaced state, never
-                            # removed keys entirely.
-                            state_delta_for_room[room_id] = ([], delta_ids)
-                        elif current_state is not None:
-                            with Measure(
-                                self._clock, "persist_events.calculate_state_delta"
-                            ):
-                                delta = yield self._calculate_state_delta(
-                                    room_id, current_state
-                                )
-                            state_delta_for_room[room_id] = delta
-
-                        # If we have the current_state then let's prefill
-                        # the cache with it.
-                        if current_state is not None:
-                            current_state_for_room[room_id] = current_state
-
-            # We want to calculate the stream orderings as late as possible, as
-            # we only notify after all events with a lesser stream ordering have
-            # been persisted. I.e. if we spend 10s inside the with block then
-            # that will delay notification of all subsequent events.
-            # Hence we do it down here rather than wrapping the entire
-            # function.
-            #
-            # It's safe to do this after calculating the state deltas etc. as we
-            # only need to protect the *persistence* of the events. This is to
-            # ensure that queries of the form "fetch events since X" don't
-            # return events and stream positions after events that are still in
-            # flight, as otherwise subsequent requests "fetch events since Y"
-            # will not return those events.
-            #
-            # Note: Multiple instances of this function cannot be in flight at
-            # the same time for the same room.
-            if backfilled:
-                stream_ordering_manager = self._backfill_id_gen.get_next_mult(
-                    len(chunk)
-                )
-            else:
-                stream_ordering_manager = self._stream_id_gen.get_next_mult(len(chunk))
-
-            with stream_ordering_manager as stream_orderings:
-                for (event, context), stream in zip(chunk, stream_orderings):
-                    event.internal_metadata.stream_ordering = stream
-
-                yield self.runInteraction(
-                    "persist_events",
-                    self._persist_events_txn,
-                    events_and_contexts=chunk,
-                    backfilled=backfilled,
-                    delete_existing=delete_existing,
-                    state_delta_for_room=state_delta_for_room,
-                    new_forward_extremeties=new_forward_extremeties,
-                )
-                persist_event_counter.inc(len(chunk))
-
-                if not backfilled:
-                    # backfilled events have negative stream orderings, so we don't
-                    # want to set the event_persisted_position to that.
-                    synapse.metrics.event_persisted_position.set(
-                        chunk[-1][0].internal_metadata.stream_ordering
-                    )
-
-                for event, context in chunk:
-                    if context.app_service:
-                        origin_type = "local"
-                        origin_entity = context.app_service.id
-                    elif self.hs.is_mine_id(event.sender):
-                        origin_type = "local"
-                        origin_entity = "*client*"
-                    else:
-                        origin_type = "remote"
-                        origin_entity = get_domain_from_id(event.sender)
-
-                    event_counter.labels(event.type, origin_type, origin_entity).inc()
-
-                for room_id, new_state in iteritems(current_state_for_room):
-                    self.get_current_state_ids.prefill((room_id,), new_state)
-
-                for room_id, latest_event_ids in iteritems(new_forward_extremeties):
-                    self.get_latest_event_ids_in_room.prefill(
-                        (room_id,), list(latest_event_ids)
-                    )
-
-    @defer.inlineCallbacks
-    def _calculate_new_extremities(self, room_id, event_contexts, latest_event_ids):
-        """Calculates the new forward extremities for a room given events to
-        persist.
-
-        Assumes that we are only persisting events for one room at a time.
-        """
-
-        # we're only interested in new events which aren't outliers and which aren't
-        # being rejected.
-        new_events = [
-            event
-            for event, ctx in event_contexts
-            if not event.internal_metadata.is_outlier()
-            and not ctx.rejected
-            and not event.internal_metadata.is_soft_failed()
-        ]
-
-        latest_event_ids = set(latest_event_ids)
-
-        # start with the existing forward extremities
-        result = set(latest_event_ids)
-
-        # add all the new events to the list
-        result.update(event.event_id for event in new_events)
-
-        # Now remove all events which are prev_events of any of the new events
-        result.difference_update(
-            e_id for event in new_events for e_id in event.prev_event_ids()
-        )
-
-        # Remove any events which are prev_events of any existing events.
-        existing_prevs = yield self._get_events_which_are_prevs(result)
-        result.difference_update(existing_prevs)
-
-        # Finally handle the case where the new events have soft-failed prev
-        # events. If they do we need to remove them and their prev events,
-        # otherwise we end up with dangling extremities.
-        existing_prevs = yield self._get_prevs_before_rejected(
-            e_id for event in new_events for e_id in event.prev_event_ids()
-        )
-        result.difference_update(existing_prevs)
-
-        # We only update metrics for events that change forward extremities
-        # (e.g. we ignore backfill/outliers/etc)
-        if result != latest_event_ids:
-            forward_extremities_counter.observe(len(result))
-            stale = latest_event_ids & result
-            stale_forward_extremities_counter.observe(len(stale))
-
-        return result
-
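# Illustrative sketch (not part of the original file): the core set algebra
# behind _calculate_new_extremities above, with the database lookups for
# already-referenced prev events and the soft-failure recursion omitted. All
# names below are hypothetical.
def sketch_new_extremities(old_extremities, new_events):
    """new_events is a list of (event_id, prev_event_ids) pairs."""
    result = set(old_extremities)
    # every new event starts as a candidate forward extremity...
    result.update(event_id for event_id, _ in new_events)
    # ...unless one of the events being persisted points at it
    referenced = {p for _, prevs in new_events for p in prevs}
    result.difference_update(referenced)
    return result

# e.g. persisting C (prev_events: [B]) on top of extremities {A, B} leaves {A, C}
assert sketch_new_extremities({"A", "B"}, [("C", ["B"])]) == {"A", "C"}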
-    @defer.inlineCallbacks
-    def _get_events_which_are_prevs(self, event_ids):
-        """Filter the supplied list of event_ids to get those which are prev_events of
-        existing (non-outlier/rejected) events.
-
-        Args:
-            event_ids (Iterable[str]): event ids to filter
-
-        Returns:
-            Deferred[List[str]]: filtered event ids
-        """
-        results = []
-
-        def _get_events_which_are_prevs_txn(txn, batch):
-            sql = """
-            SELECT prev_event_id, internal_metadata
-            FROM event_edges
-                INNER JOIN events USING (event_id)
-                LEFT JOIN rejections USING (event_id)
-                LEFT JOIN event_json USING (event_id)
-            WHERE
-                prev_event_id IN (%s)
-                AND NOT events.outlier
-                AND rejections.event_id IS NULL
-            """ % (
-                ",".join("?" for _ in batch),
-            )
-
-            txn.execute(sql, batch)
-            results.extend(r[0] for r in txn if not json.loads(r[1]).get("soft_failed"))
-
-        for chunk in batch_iter(event_ids, 100):
-            yield self.runInteraction(
-                "_get_events_which_are_prevs", _get_events_which_are_prevs_txn, chunk
-            )
-
-        return results
-
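# Illustrative sketch of the chunked "IN (?, ?, ...)" pattern used above: ids
# are queried in batches of 100 so the placeholder list stays bounded.
# `run_query` is a hypothetical stand-in for the real transaction helpers.
def chunks(items, size):
    items = list(items)
    for i in range(0, len(items), size):
        yield items[i : i + size]

def referenced_prev_events(run_query, event_ids):
    found = []
    for batch in chunks(event_ids, 100):
        placeholders = ",".join("?" for _ in batch)
        sql = "SELECT prev_event_id FROM event_edges WHERE prev_event_id IN (%s)" % (
            placeholders,
        )
        found.extend(row[0] for row in run_query(sql, batch))
    return found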
-    @defer.inlineCallbacks
-    def _get_prevs_before_rejected(self, event_ids):
-        """Get soft-failed ancestors to remove from the extremities.
-
-        Given a set of events, find all those that have been soft-failed or
-        rejected. Returns those soft failed/rejected events and their prev
-        events (whether soft-failed/rejected or not), and recurses up the
-        prev-event graph until it finds no more soft-failed/rejected events.
-
-        This is used to find extremities that are ancestors of new events, but
-        are separated by soft failed events.
-
-        Args:
-            event_ids (Iterable[str]): Events to find prev events for. Note
-                that these must have already been persisted.
-
-        Returns:
-            Deferred[set[str]]
-        """
-
-        # The set of event_ids to return. This includes all soft-failed events
-        # and their prev events.
-        existing_prevs = set()
-
-        def _get_prevs_before_rejected_txn(txn, batch):
-            to_recursively_check = batch
-
-            while to_recursively_check:
-                sql = """
-                SELECT
-                    event_id, prev_event_id, internal_metadata,
-                    rejections.event_id IS NOT NULL
-                FROM event_edges
-                    INNER JOIN events USING (event_id)
-                    LEFT JOIN rejections USING (event_id)
-                    LEFT JOIN event_json USING (event_id)
-                WHERE
-                    event_id IN (%s)
-                    AND NOT events.outlier
-                """ % (
-                    ",".join("?" for _ in to_recursively_check),
-                )
-
-                txn.execute(sql, to_recursively_check)
-                to_recursively_check = []
-
-                for event_id, prev_event_id, metadata, rejected in txn:
-                    if prev_event_id in existing_prevs:
-                        continue
-
-                    soft_failed = json.loads(metadata).get("soft_failed")
-                    if soft_failed or rejected:
-                        to_recursively_check.append(prev_event_id)
-                        existing_prevs.add(prev_event_id)
-
-        for chunk in batch_iter(event_ids, 100):
-            yield self.runInteraction(
-                "_get_prevs_before_rejected", _get_prevs_before_rejected_txn, chunk
-            )
-
-        return existing_prevs
-
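# Illustrative in-memory version of the recursion above: walk up prev_events,
# collecting the prevs of soft-failed/rejected events until only accepted
# events remain. `graph` maps event_id -> (prev_event_ids, is_bad), a
# hypothetical simplification of the joined tables queried above.
def sketch_prevs_before_rejected(graph, event_ids):
    existing_prevs = set()
    to_check = list(event_ids)
    while to_check:
        event_id = to_check.pop()
        prev_ids, is_bad = graph.get(event_id, ([], False))
        if not is_bad:
            continue
        for prev_id in prev_ids:
            if prev_id not in existing_prevs:
                existing_prevs.add(prev_id)
                to_check.append(prev_id)
    return existing_prevs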
-    @defer.inlineCallbacks
-    def _get_new_state_after_events(
-        self, room_id, events_context, old_latest_event_ids, new_latest_event_ids
-    ):
-        """Calculate the current state dict after adding some new events to
-        a room
-
-        Args:
-            room_id (str):
-                room to which the events are being added. Used for logging etc
-
-            events_context (list[(EventBase, EventContext)]):
-                events and contexts which are being added to the room
-
-            old_latest_event_ids (iterable[str]):
-                the old forward extremities for the room.
-
-            new_latest_event_ids (iterable[str]):
-                the new forward extremities for the room.
-
-        Returns:
-            Deferred[tuple[dict[(str,str), str]|None, dict[(str,str), str]|None]]:
-            Returns a tuple of two state maps, the first being the full new current
-            state and the second being the delta to the existing current state.
-            If both are None then there has been no change.
-
-            If there has been a change then we only return the delta if it's
-            already been calculated. Conversely if we do know the delta then
-            the new current state is only returned if we've already calculated
-            it.
-        """
-        # map from state_group to ((type, key) -> event_id) state map
-        state_groups_map = {}
-
-        # Map from (prev state group, new state group) -> delta state dict
-        state_group_deltas = {}
-
-        for ev, ctx in events_context:
-            if ctx.state_group is None:
-                # This should only happen for outlier events.
-                if not ev.internal_metadata.is_outlier():
-                    raise Exception(
-                        "Context for new event %s has no state "
-                        "group" % (ev.event_id,)
-                    )
-                continue
-
-            if ctx.state_group in state_groups_map:
-                continue
-
-            # We're only interested in pulling out state that has already
-            # been cached in the context. We'll pull stuff out of the DB later
-            # if necessary.
-            current_state_ids = ctx.get_cached_current_state_ids()
-            if current_state_ids is not None:
-                state_groups_map[ctx.state_group] = current_state_ids
-
-            if ctx.prev_group:
-                state_group_deltas[(ctx.prev_group, ctx.state_group)] = ctx.delta_ids
-
-        # We need to map the event_ids to their state groups. First, let's
-        # check if the event is one we're persisting, in which case we can
-        # pull the state group from its context.
-        # Otherwise we need to pull the state group from the database.
-
-        # Set of events we need to fetch groups for. (We know none of the old
-        # extremities are going to be in events_context).
-        missing_event_ids = set(old_latest_event_ids)
-
-        event_id_to_state_group = {}
-        for event_id in new_latest_event_ids:
-            # First search in the list of new events we're adding.
-            for ev, ctx in events_context:
-                if event_id == ev.event_id and ctx.state_group is not None:
-                    event_id_to_state_group[event_id] = ctx.state_group
-                    break
-            else:
-                # If we couldn't find it, then we'll need to pull
-                # the state from the database
-                missing_event_ids.add(event_id)
-
-        if missing_event_ids:
-            # Now pull out the state groups for any missing events from DB
-            event_to_groups = yield self._get_state_group_for_events(missing_event_ids)
-            event_id_to_state_group.update(event_to_groups)
-
-        # State groups of old_latest_event_ids
-        old_state_groups = set(
-            event_id_to_state_group[evid] for evid in old_latest_event_ids
-        )
-
-        # State groups of new_latest_event_ids
-        new_state_groups = set(
-            event_id_to_state_group[evid] for evid in new_latest_event_ids
-        )
-
-        # If the old and new groups are the same then we don't need to do
-        # anything.
-        if old_state_groups == new_state_groups:
-            return None, None
-
-        if len(new_state_groups) == 1 and len(old_state_groups) == 1:
-            # If we're going from one state group to another, let's check if
-            # we have a delta for that transition. If we do then we can just
-            # return that.
-
-            new_state_group = next(iter(new_state_groups))
-            old_state_group = next(iter(old_state_groups))
-
-            delta_ids = state_group_deltas.get((old_state_group, new_state_group), None)
-            if delta_ids is not None:
-                # We have a delta from the existing to new current state,
-                # so let's just return that. If we happen to already have
-                # the current state in memory then let's also return that,
-                # but it doesn't matter if we don't.
-                new_state = state_groups_map.get(new_state_group)
-                return new_state, delta_ids
-
-        # Now that we have calculated new_state_groups we need to get
-        # their state IDs so we can resolve to a single state set.
-        missing_state = new_state_groups - set(state_groups_map)
-        if missing_state:
-            group_to_state = yield self._get_state_for_groups(missing_state)
-            state_groups_map.update(group_to_state)
-
-        if len(new_state_groups) == 1:
-            # If there is only one state group, then we know what the current
-            # state is.
-            return state_groups_map[new_state_groups.pop()], None
-
-        # Ok, we need to defer to the state handler to resolve our state sets.
-
-        state_groups = {sg: state_groups_map[sg] for sg in new_state_groups}
-
-        events_map = {ev.event_id: ev for ev, _ in events_context}
-
-        # We need to get the room version, which is in the create event.
-        # Normally that'd be in the database, but it's also possible that we're
-        # currently trying to persist it.
-        room_version = None
-        for ev, _ in events_context:
-            if ev.type == EventTypes.Create and ev.state_key == "":
-                room_version = ev.content.get("room_version", "1")
-                break
-
-        if not room_version:
-            room_version = yield self.get_room_version(room_id)
-
-        logger.debug("calling resolve_state_groups from persist_events")
-        res = yield self._state_resolution_handler.resolve_state_groups(
-            room_id,
-            room_version,
-            state_groups,
-            events_map,
-            state_res_store=StateResolutionStore(self),
-        )
-
-        return res.state, None
-
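# Illustrative sketch of the decision ladder above, with storage and state
# resolution stubbed out; it only shows which of (current_state, delta) is
# returned in each case. `resolve` and `group_state` are hypothetical stand-ins
# for the state resolution handler and the per-group state maps.
def sketch_new_state(old_groups, new_groups, known_deltas, group_state, resolve):
    if old_groups == new_groups:
        return None, None  # no change at all
    if len(old_groups) == 1 and len(new_groups) == 1:
        old_sg, new_sg = next(iter(old_groups)), next(iter(new_groups))
        delta = known_deltas.get((old_sg, new_sg))
        if delta is not None:
            # delta known up front; full state only if it happens to be cached
            return group_state.get(new_sg), delta
    if len(new_groups) == 1:
        return group_state[next(iter(new_groups))], None
    # several extremity state groups: defer to state resolution
    return resolve([group_state[sg] for sg in new_groups]), None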
-    @defer.inlineCallbacks
-    def _calculate_state_delta(self, room_id, current_state):
-        """Calculate the new state deltas for a room.
-
-        Assumes that we are only persisting events for one room at a time.
-
-        Returns:
-            tuple[list, dict] (to_delete, to_insert): where to_delete are the
-            type/state_keys to remove from current_state_events and `to_insert`
-            are the updates to current_state_events.
-        """
-        existing_state = yield self.get_current_state_ids(room_id)
-
-        to_delete = [key for key in existing_state if key not in current_state]
-
-        to_insert = {
-            key: ev_id
-            for key, ev_id in iteritems(current_state)
-            if ev_id != existing_state.get(key)
-        }
-
-        return to_delete, to_insert
-
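# A small worked example of the delta computed above: keys present in the
# existing current state but missing from the new state are deleted, and keys
# whose event id changed (or are new) are (re)inserted. Values are hypothetical.
existing_state = {
    ("m.room.member", "@a:test"): "$join_a",
    ("m.room.topic", ""): "$topic_1",
}
current_state = {
    ("m.room.member", "@a:test"): "$leave_a",  # replaced
    ("m.room.name", ""): "$name_1",  # newly added; the topic key disappears
}

to_delete = [key for key in existing_state if key not in current_state]
to_insert = {
    key: ev_id
    for key, ev_id in current_state.items()
    if ev_id != existing_state.get(key)
}
assert to_delete == [("m.room.topic", "")]
assert to_insert == {
    ("m.room.member", "@a:test"): "$leave_a",
    ("m.room.name", ""): "$name_1",
}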
-    @log_function
-    def _persist_events_txn(
-        self,
-        txn,
-        events_and_contexts,
-        backfilled,
-        delete_existing=False,
-        state_delta_for_room={},
-        new_forward_extremeties={},
-    ):
-        """Insert some number of room events into the necessary database tables.
-
-        Rejected events are only inserted into the events table, the event_json
-        table, and the rejections table. Things reading from those tables will
-        need to check whether the event was rejected.
-
-        Args:
-            txn (twisted.enterprise.adbapi.Connection): db connection
-            events_and_contexts (list[(EventBase, EventContext)]):
-                events to persist
-            backfilled (bool): True if the events were backfilled
-            delete_existing (bool): True to purge existing table rows for the
-                events from the database. This is useful when retrying due to
-                IntegrityError.
-            state_delta_for_room (dict[str, (list, dict)]):
-                The current-state delta for each room. For each room, a tuple
-                (to_delete, to_insert), being a list of type/state keys to be
-                removed from the current state, and a state set to be added to
-                the current state.
-            new_forward_extremeties (dict[str, list[str]]):
-                The new forward extremities for each room. For each room, a
-                list of the event ids which are the forward extremities.
-
-        """
-        all_events_and_contexts = events_and_contexts
-
-        min_stream_order = events_and_contexts[0][0].internal_metadata.stream_ordering
-        max_stream_order = events_and_contexts[-1][0].internal_metadata.stream_ordering
-
-        self._update_forward_extremities_txn(
-            txn,
-            new_forward_extremities=new_forward_extremeties,
-            max_stream_order=max_stream_order,
-        )
-
-        # Ensure that we don't have the same event twice.
-        events_and_contexts = self._filter_events_and_contexts_for_duplicates(
-            events_and_contexts
-        )
-
-        self._update_room_depths_txn(
-            txn, events_and_contexts=events_and_contexts, backfilled=backfilled
-        )
-
-        # _update_outliers_txn filters out any events which have already been
-        # persisted, and returns the filtered list.
-        events_and_contexts = self._update_outliers_txn(
-            txn, events_and_contexts=events_and_contexts
-        )
-
-        # From this point onwards the events are only events that we haven't
-        # seen before.
-
-        if delete_existing:
-            # For paranoia reasons, we go and delete all the existing entries
-            # for these events so we can reinsert them.
-            # This gets around any problems with some tables already having
-            # entries.
-            self._delete_existing_rows_txn(txn, events_and_contexts=events_and_contexts)
-
-        self._store_event_txn(txn, events_and_contexts=events_and_contexts)
-
-        # Insert into event_to_state_groups.
-        self._store_event_state_mappings_txn(txn, events_and_contexts)
-
-        # We want to store event_auth mappings for rejected events, as they're
-        # used in state res v2.
-        # This is only necessary if the rejected event appears in an accepted
-        # event's auth chain, but it's easier for now just to store them (and
-        # it doesn't take much storage compared to storing the entire event
-        # anyway).
-        self._simple_insert_many_txn(
-            txn,
-            table="event_auth",
-            values=[
-                {
-                    "event_id": event.event_id,
-                    "room_id": event.room_id,
-                    "auth_id": auth_id,
-                }
-                for event, _ in events_and_contexts
-                for auth_id in event.auth_event_ids()
-                if event.is_state()
-            ],
-        )
-
-        # _store_rejected_events_txn filters out any events which were
-        # rejected, and returns the filtered list.
-        events_and_contexts = self._store_rejected_events_txn(
-            txn, events_and_contexts=events_and_contexts
-        )
-
-        # From this point onwards the events are only ones that weren't
-        # rejected.
-
-        self._update_metadata_tables_txn(
-            txn,
-            events_and_contexts=events_and_contexts,
-            all_events_and_contexts=all_events_and_contexts,
-            backfilled=backfilled,
-        )
-
-        # We call this last as it assumes we've inserted the events into
-        # room_memberships, where applicable.
-        self._update_current_state_txn(txn, state_delta_for_room, min_stream_order)
-
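# For orientation, the order of operations inside the transaction above,
# summarised as plain strings (a reading aid, not executable steps):
PERSIST_TXN_STEPS = [
    "update event_forward_extremities and stream_ordering_to_exterm",
    "drop duplicate events, preferring non-outlier copies",
    "update room depths and invalidate event caches",
    "turn previously-stored outliers into ex-outliers",
    "optionally delete existing rows when retrying after IntegrityError",
    "insert into event_json and events",
    "insert event_to_state_groups and event_auth mappings",
    "split off rejected events into the rejections table",
    "update metadata tables (push actions, search, memberships, ...)",
    "apply the current_state_events delta last",
]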
-    def _update_current_state_txn(self, txn, state_delta_by_room, stream_id):
-        for room_id, current_state_tuple in iteritems(state_delta_by_room):
-            to_delete, to_insert = current_state_tuple
-
-            # First we add entries to the current_state_delta_stream. We
-            # do this before updating the current_state_events table so
-            # that we can use it to calculate the `prev_event_id`. (This
-            # allows us to not have to pull out the existing state
-            # unnecessarily).
-            #
-            # The stream_id for the update is chosen to be the minimum of the stream_ids
-            # for the batch of the events that we are persisting; that means we do not
-            # end up in a situation where workers see events before the
-            # current_state_delta updates.
-            #
-            sql = """
-                INSERT INTO current_state_delta_stream
-                (stream_id, room_id, type, state_key, event_id, prev_event_id)
-                SELECT ?, ?, ?, ?, ?, (
-                    SELECT event_id FROM current_state_events
-                    WHERE room_id = ? AND type = ? AND state_key = ?
-                )
-            """
-            txn.executemany(
-                sql,
-                (
-                    (
-                        stream_id,
-                        room_id,
-                        etype,
-                        state_key,
-                        None,
-                        room_id,
-                        etype,
-                        state_key,
-                    )
-                    for etype, state_key in to_delete
-                    # We sanity check that we're deleting rather than updating
-                    if (etype, state_key) not in to_insert
-                ),
-            )
-            txn.executemany(
-                sql,
-                (
-                    (
-                        stream_id,
-                        room_id,
-                        etype,
-                        state_key,
-                        ev_id,
-                        room_id,
-                        etype,
-                        state_key,
-                    )
-                    for (etype, state_key), ev_id in iteritems(to_insert)
-                ),
-            )
-
-            # Now we actually update the current_state_events table
-
-            txn.executemany(
-                "DELETE FROM current_state_events"
-                " WHERE room_id = ? AND type = ? AND state_key = ?",
-                (
-                    (room_id, etype, state_key)
-                    for etype, state_key in itertools.chain(to_delete, to_insert)
-                ),
-            )
-
-            # We include the membership in the current state table, hence we do
-            # a lookup when we insert. This assumes that all events have already
-            # been inserted into room_memberships.
-            txn.executemany(
-                """INSERT INTO current_state_events
-                    (room_id, type, state_key, event_id, membership)
-                VALUES (?, ?, ?, ?, (SELECT membership FROM room_memberships WHERE event_id = ?))
-                """,
-                [
-                    (room_id, key[0], key[1], ev_id, ev_id)
-                    for key, ev_id in iteritems(to_insert)
-                ],
-            )
-
-            txn.call_after(
-                self._curr_state_delta_stream_cache.entity_has_changed,
-                room_id,
-                stream_id,
-            )
-
-            # Invalidate the various caches
-
-            # Figure out the changes of membership to invalidate the
-            # `get_rooms_for_user` cache.
-            # We find out which membership events we may have deleted
-            # and which we have added, then we invalidate the caches for all
-            # those users.
-            members_changed = set(
-                state_key
-                for ev_type, state_key in itertools.chain(to_delete, to_insert)
-                if ev_type == EventTypes.Member
-            )
-
-            for member in members_changed:
-                txn.call_after(
-                    self.get_rooms_for_user_with_stream_ordering.invalidate, (member,)
-                )
-
-            self._invalidate_state_caches_and_stream(txn, room_id, members_changed)
-
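# Illustrative sketch of the rows fed to the two executemany() calls above:
# deletions carry a NULL event_id, insertions carry the new event_id, and both
# repeat room_id/type/state_key so the sub-select can resolve prev_event_id.
# The argument names mirror the code above; this helper is not called anywhere.
def sketch_delta_stream_rows(stream_id, room_id, to_delete, to_insert):
    rows = []
    for etype, state_key in to_delete:
        if (etype, state_key) in to_insert:
            continue  # replaced rather than deleted
        rows.append(
            (stream_id, room_id, etype, state_key, None, room_id, etype, state_key)
        )
    for (etype, state_key), ev_id in to_insert.items():
        rows.append(
            (stream_id, room_id, etype, state_key, ev_id, room_id, etype, state_key)
        )
    return rows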
-    def _update_forward_extremities_txn(
-        self, txn, new_forward_extremities, max_stream_order
-    ):
-        for room_id, new_extrem in iteritems(new_forward_extremities):
-            self._simple_delete_txn(
-                txn, table="event_forward_extremities", keyvalues={"room_id": room_id}
-            )
-            txn.call_after(self.get_latest_event_ids_in_room.invalidate, (room_id,))
-
-        self._simple_insert_many_txn(
-            txn,
-            table="event_forward_extremities",
-            values=[
-                {"event_id": ev_id, "room_id": room_id}
-                for room_id, new_extrem in iteritems(new_forward_extremities)
-                for ev_id in new_extrem
-            ],
-        )
-        # We now insert into stream_ordering_to_exterm a mapping from room_id,
-        # new stream_ordering to new forward extremities in the room.
-        # This allows us to later efficiently look up the forward extremities
-        # for a room before a given stream_ordering
-        self._simple_insert_many_txn(
-            txn,
-            table="stream_ordering_to_exterm",
-            values=[
-                {
-                    "room_id": room_id,
-                    "event_id": event_id,
-                    "stream_ordering": max_stream_order,
-                }
-                for room_id, new_extrem in iteritems(new_forward_extremities)
-                for event_id in new_extrem
-            ],
-        )
-
-    @classmethod
-    def _filter_events_and_contexts_for_duplicates(cls, events_and_contexts):
-        """Ensure that we don't have the same event twice.
-
-        Pick the earliest non-outlier if there is one, else the earliest one.
-
-        Args:
-            events_and_contexts (list[(EventBase, EventContext)]):
-        Returns:
-            list[(EventBase, EventContext)]: filtered list
-        """
-        new_events_and_contexts = OrderedDict()
-        for event, context in events_and_contexts:
-            prev_event_context = new_events_and_contexts.get(event.event_id)
-            if prev_event_context:
-                if not event.internal_metadata.is_outlier():
-                    if prev_event_context[0].internal_metadata.is_outlier():
-                        # To ensure correct ordering we pop, as OrderedDict is
-                        # ordered by first insertion.
-                        new_events_and_contexts.pop(event.event_id, None)
-                        new_events_and_contexts[event.event_id] = (event, context)
-            else:
-                new_events_and_contexts[event.event_id] = (event, context)
-        return list(new_events_and_contexts.values())
-
-    def _update_room_depths_txn(self, txn, events_and_contexts, backfilled):
-        """Update min_depth for each room
-
-        Args:
-            txn (twisted.enterprise.adbapi.Connection): db connection
-            events_and_contexts (list[(EventBase, EventContext)]): events
-                we are persisting
-            backfilled (bool): True if the events were backfilled
-        """
-        depth_updates = {}
-        for event, context in events_and_contexts:
-            # Remove any existing cache entries for the event_ids
-            txn.call_after(self._invalidate_get_event_cache, event.event_id)
-            if not backfilled:
-                txn.call_after(
-                    self._events_stream_cache.entity_has_changed,
-                    event.room_id,
-                    event.internal_metadata.stream_ordering,
-                )
-
-            if not event.internal_metadata.is_outlier() and not context.rejected:
-                depth_updates[event.room_id] = max(
-                    event.depth, depth_updates.get(event.room_id, event.depth)
-                )
-
-        for room_id, depth in iteritems(depth_updates):
-            self._update_min_depth_for_room_txn(txn, room_id, depth)
-
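# Tiny illustration of the depth bookkeeping above: per room we keep the
# maximum depth seen among the non-outlier, non-rejected events in the batch,
# which is then handed to _update_min_depth_for_room_txn. Values are made up.
batch = [("!room:test", 5), ("!room:test", 7), ("!other:test", 3)]
depth_updates = {}
for room_id, depth in batch:
    depth_updates[room_id] = max(depth, depth_updates.get(room_id, depth))
assert depth_updates == {"!room:test": 7, "!other:test": 3}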
-    def _update_outliers_txn(self, txn, events_and_contexts):
-        """Update any outliers with new event info.
-
-        This turns outliers into ex-outliers (unless the new event was
-        rejected).
-
-        Args:
-            txn (twisted.enterprise.adbapi.Connection): db connection
-            events_and_contexts (list[(EventBase, EventContext)]): events
-                we are persisting
-
-        Returns:
-            list[(EventBase, EventContext)] new list, without events which
-            are already in the events table.
-        """
-        txn.execute(
-            "SELECT event_id, outlier FROM events WHERE event_id in (%s)"
-            % (",".join(["?"] * len(events_and_contexts)),),
-            [event.event_id for event, _ in events_and_contexts],
-        )
-
-        have_persisted = {event_id: outlier for event_id, outlier in txn}
-
-        to_remove = set()
-        for event, context in events_and_contexts:
-            if event.event_id not in have_persisted:
-                continue
-
-            to_remove.add(event)
-
-            if context.rejected:
-                # If the event is rejected then we don't care if the event
-                # was an outlier or not.
-                continue
-
-            outlier_persisted = have_persisted[event.event_id]
-            if not event.internal_metadata.is_outlier() and outlier_persisted:
-                # We received a copy of an event that we had already stored as
-                # an outlier in the database. We now have some state at that
-                # event, so we need to update the state_groups table with that
-                # state.
-
-                # insert into event_to_state_groups.
-                try:
-                    self._store_event_state_mappings_txn(txn, ((event, context),))
-                except Exception:
-                    logger.exception("")
-                    raise
-
-                metadata_json = encode_json(event.internal_metadata.get_dict())
-
-                sql = "UPDATE event_json SET internal_metadata = ? WHERE event_id = ?"
-                txn.execute(sql, (metadata_json, event.event_id))
-
-                # Add an entry to the ex_outlier_stream table to replicate the
-                # change in outlier status to our workers.
-                stream_order = event.internal_metadata.stream_ordering
-                state_group_id = context.state_group
-                self._simple_insert_txn(
-                    txn,
-                    table="ex_outlier_stream",
-                    values={
-                        "event_stream_ordering": stream_order,
-                        "event_id": event.event_id,
-                        "state_group": state_group_id,
-                    },
-                )
-
-                sql = "UPDATE events SET outlier = ?" " WHERE event_id = ?"
-                txn.execute(sql, (False, event.event_id))
-
-                # Update the event_backward_extremities table now that this
-                # event isn't an outlier any more.
-                self._update_backward_extremeties(txn, [event])
-
-        return [ec for ec in events_and_contexts if ec[0] not in to_remove]
-
-    @classmethod
-    def _delete_existing_rows_txn(cls, txn, events_and_contexts):
-        if not events_and_contexts:
-            # nothing to do here
-            return
-
-        logger.info("Deleting existing")
-
-        for table in (
-            "events",
-            "event_auth",
-            "event_json",
-            "event_edges",
-            "event_forward_extremities",
-            "event_reference_hashes",
-            "event_search",
-            "event_to_state_groups",
-            "local_invites",
-            "state_events",
-            "rejections",
-            "redactions",
-            "room_memberships",
-        ):
-            txn.executemany(
-                "DELETE FROM %s WHERE event_id = ?" % (table,),
-                [(ev.event_id,) for ev, _ in events_and_contexts],
-            )
-
-        for table in ("event_push_actions",):
-            txn.executemany(
-                "DELETE FROM %s WHERE room_id = ? AND event_id = ?" % (table,),
-                [(ev.room_id, ev.event_id) for ev, _ in events_and_contexts],
-            )
-
-    def _store_event_txn(self, txn, events_and_contexts):
-        """Insert new events into the event and event_json tables
-
-        Args:
-            txn (twisted.enterprise.adbapi.Connection): db connection
-            events_and_contexts (list[(EventBase, EventContext)]): events
-                we are persisting
-        """
-
-        if not events_and_contexts:
-            # nothing to do here
-            return
-
-        def event_dict(event):
-            d = event.get_dict()
-            d.pop("redacted", None)
-            d.pop("redacted_because", None)
-            return d
-
-        self._simple_insert_many_txn(
-            txn,
-            table="event_json",
-            values=[
-                {
-                    "event_id": event.event_id,
-                    "room_id": event.room_id,
-                    "internal_metadata": encode_json(
-                        event.internal_metadata.get_dict()
-                    ),
-                    "json": encode_json(event_dict(event)),
-                    "format_version": event.format_version,
-                }
-                for event, _ in events_and_contexts
-            ],
-        )
-
-        self._simple_insert_many_txn(
-            txn,
-            table="events",
-            values=[
-                {
-                    "stream_ordering": event.internal_metadata.stream_ordering,
-                    "topological_ordering": event.depth,
-                    "depth": event.depth,
-                    "event_id": event.event_id,
-                    "room_id": event.room_id,
-                    "type": event.type,
-                    "processed": True,
-                    "outlier": event.internal_metadata.is_outlier(),
-                    "origin_server_ts": int(event.origin_server_ts),
-                    "received_ts": self._clock.time_msec(),
-                    "sender": event.sender,
-                    "contains_url": (
-                        "url" in event.content
-                        and isinstance(event.content["url"], text_type)
-                    ),
-                }
-                for event, _ in events_and_contexts
-            ],
-        )
-
-    def _store_rejected_events_txn(self, txn, events_and_contexts):
-        """Add rows to the 'rejections' table for received events which were
-        rejected
-
-        Args:
-            txn (twisted.enterprise.adbapi.Connection): db connection
-            events_and_contexts (list[(EventBase, EventContext)]): events
-                we are persisting
-
-        Returns:
-            list[(EventBase, EventContext)] new list, without the rejected
-                events.
-        """
-        # Remove the rejected events from the list now that we've added them
-        # to the events table and the event_json table.
-        to_remove = set()
-        for event, context in events_and_contexts:
-            if context.rejected:
-                # Insert the event_id into the rejections table
-                self._store_rejections_txn(txn, event.event_id, context.rejected)
-                to_remove.add(event)
-
-        return [ec for ec in events_and_contexts if ec[0] not in to_remove]
-
-    def _update_metadata_tables_txn(
-        self, txn, events_and_contexts, all_events_and_contexts, backfilled
-    ):
-        """Update all the miscellaneous tables for new events
-
-        Args:
-            txn (twisted.enterprise.adbapi.Connection): db connection
-            events_and_contexts (list[(EventBase, EventContext)]): events
-                we are persisting
-            all_events_and_contexts (list[(EventBase, EventContext)]): all
-                events that we were going to persist. This includes events
-                we've already persisted, etc, that wouldn't appear in
-                events_and_contexts.
-            backfilled (bool): True if the events were backfilled
-        """
-
-        # Insert all the push actions into the event_push_actions table.
-        self._set_push_actions_for_event_and_users_txn(
-            txn,
-            events_and_contexts=events_and_contexts,
-            all_events_and_contexts=all_events_and_contexts,
-        )
-
-        if not events_and_contexts:
-            # nothing to do here
-            return
-
-        for event, context in events_and_contexts:
-            if event.type == EventTypes.Redaction and event.redacts is not None:
-                # Remove the entries in the event_push_actions table for the
-                # redacted event.
-                self._remove_push_actions_for_event_id_txn(
-                    txn, event.room_id, event.redacts
-                )
-
-                # Remove from relations table.
-                self._handle_redaction(txn, event.redacts)
-
-        # Update the event_forward_extremities, event_backward_extremities and
-        # event_edges tables.
-        self._handle_mult_prev_events(
-            txn, events=[event for event, _ in events_and_contexts]
-        )
-
-        for event, _ in events_and_contexts:
-            if event.type == EventTypes.Name:
-                # Insert into the event_search table.
-                self._store_room_name_txn(txn, event)
-            elif event.type == EventTypes.Topic:
-                # Insert into the event_search table.
-                self._store_room_topic_txn(txn, event)
-            elif event.type == EventTypes.Message:
-                # Insert into the event_search table.
-                self._store_room_message_txn(txn, event)
-            elif event.type == EventTypes.Redaction:
-                # Insert into the redactions table.
-                self._store_redaction(txn, event)
-
-            self._handle_event_relations(txn, event)
-
-        # Insert into the room_memberships table.
-        self._store_room_members_txn(
-            txn,
-            [
-                event
-                for event, _ in events_and_contexts
-                if event.type == EventTypes.Member
-            ],
-            backfilled=backfilled,
-        )
-
-        # Insert into the event_reference_hashes table.
-        self._store_event_reference_hashes_txn(
-            txn, [event for event, _ in events_and_contexts]
-        )
-
-        state_events_and_contexts = [
-            ec for ec in events_and_contexts if ec[0].is_state()
-        ]
-
-        state_values = []
-        for event, context in state_events_and_contexts:
-            vals = {
-                "event_id": event.event_id,
-                "room_id": event.room_id,
-                "type": event.type,
-                "state_key": event.state_key,
-            }
-
-            # TODO: How does this work with backfilling?
-            if hasattr(event, "replaces_state"):
-                vals["prev_state"] = event.replaces_state
-
-            state_values.append(vals)
-
-        self._simple_insert_many_txn(txn, table="state_events", values=state_values)
-
-        # Prefill the event cache
-        self._add_to_cache(txn, events_and_contexts)
-
-    def _add_to_cache(self, txn, events_and_contexts):
-        to_prefill = []
-
-        rows = []
-        N = 200
-        for i in range(0, len(events_and_contexts), N):
-            ev_map = {e[0].event_id: e[0] for e in events_and_contexts[i : i + N]}
-            if not ev_map:
-                break
-
-            sql = (
-                "SELECT "
-                " e.event_id as event_id, "
-                " r.redacts as redacts,"
-                " rej.event_id as rejects "
-                " FROM events as e"
-                " LEFT JOIN rejections as rej USING (event_id)"
-                " LEFT JOIN redactions as r ON e.event_id = r.redacts"
-                " WHERE e.event_id IN (%s)"
-            ) % (",".join(["?"] * len(ev_map)),)
-
-            txn.execute(sql, list(ev_map))
-            rows = self.cursor_to_dict(txn)
-            for row in rows:
-                event = ev_map[row["event_id"]]
-                if not row["rejects"] and not row["redacts"]:
-                    to_prefill.append(
-                        _EventCacheEntry(event=event, redacted_event=None)
-                    )
-
-        def prefill():
-            for cache_entry in to_prefill:
-                self._get_event_cache.prefill((cache_entry[0].event_id,), cache_entry)
-
-        txn.call_after(prefill)
-
-    def _store_redaction(self, txn, event):
-        # invalidate the cache for the redacted event
-        txn.call_after(self._invalidate_get_event_cache, event.redacts)
-        txn.execute(
-            "INSERT INTO redactions (event_id, redacts) VALUES (?,?)",
-            (event.event_id, event.redacts),
-        )
-
-    @defer.inlineCallbacks
-    def _censor_redactions(self):
-        """Censors all redactions older than the configured period that haven't
-        been censored yet.
-
-        By censor we mean update the event_json table with the redacted event.
-
-        Returns:
-            Deferred
-        """
-
-        if self.hs.config.redaction_retention_period is None:
-            return
-
-        max_pos = yield self.find_first_stream_ordering_after_ts(
-            self._clock.time_msec() - self.hs.config.redaction_retention_period
-        )
-
-        # We fetch all redactions that:
-        #   1. point to an event we have,
-        #   2. have a stream ordering from before the cut off, and
-        #   3. we haven't yet censored.
-        #
-        # This is limited to 100 events to ensure that we don't try and do too
-        # much at once. We'll get called again so this should eventually catch
-        # up.
-        #
-        # We use the range [-max_pos, max_pos] to handle backfilled events,
-        # which are given negative stream ordering.
-        sql = """
-            SELECT redact_event.event_id, redacts FROM redactions
-            INNER JOIN events AS redact_event USING (event_id)
-            INNER JOIN events AS original_event ON (
-                redact_event.room_id = original_event.room_id
-                AND redacts = original_event.event_id
-            )
-            WHERE NOT have_censored
-            AND ? <= redact_event.stream_ordering AND redact_event.stream_ordering <= ?
-            ORDER BY redact_event.stream_ordering ASC
-            LIMIT ?
-        """
-
-        rows = yield self._execute(
-            "_censor_redactions_fetch", None, sql, -max_pos, max_pos, 100
-        )
-
-        updates = []
-
-        for redaction_id, event_id in rows:
-            redaction_event = yield self.get_event(redaction_id, allow_none=True)
-            original_event = yield self.get_event(
-                event_id, allow_rejected=True, allow_none=True
-            )
-
-            # The SQL above ensures that we have both the redaction and
-            # original event, so if the `get_event` calls return None it
-            # means that the redaction wasn't allowed. Either way we know that
-            # the result won't change so we mark the fact that we've checked.
-            if (
-                redaction_event
-                and original_event
-                and original_event.internal_metadata.is_redacted()
-            ):
-                # Redaction was allowed
-                pruned_json = encode_canonical_json(
-                    prune_event_dict(original_event.get_dict())
-                )
-            else:
-                # Redaction wasn't allowed
-                pruned_json = None
-
-            updates.append((redaction_id, event_id, pruned_json))
-
-        def _update_censor_txn(txn):
-            for redaction_id, event_id, pruned_json in updates:
-                if pruned_json:
-                    self._simple_update_one_txn(
-                        txn,
-                        table="event_json",
-                        keyvalues={"event_id": event_id},
-                        updatevalues={"json": pruned_json},
-                    )
-
-                self._simple_update_one_txn(
-                    txn,
-                    table="redactions",
-                    keyvalues={"event_id": redaction_id},
-                    updatevalues={"have_censored": True},
-                )
-
-        yield self.runInteraction("_update_censor_txn", _update_censor_txn)
-
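# Illustrative sketch of the per-row decision made above: if both events could
# be loaded and the original is marked as redacted, store pruned JSON; otherwise
# just mark the redaction as checked so it is not revisited. `get_event` and
# `prune` are hypothetical stand-ins for the real helpers.
def sketch_censor_updates(rows, get_event, prune):
    updates = []
    for redaction_id, event_id in rows:
        redaction_event = get_event(redaction_id)
        original_event = get_event(event_id)
        if (
            redaction_event
            and original_event
            and original_event.internal_metadata.is_redacted()
        ):
            pruned_json = prune(original_event.get_dict())
        else:
            pruned_json = None  # redaction not applied; only have_censored is set
        updates.append((redaction_id, event_id, pruned_json))
    return updates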
-    @defer.inlineCallbacks
-    def count_daily_messages(self):
-        """
-        Returns an estimate of the number of messages sent in the last day.
-
-        If it has been significantly less or more than one day since the last
-        call to this function, it will return None.
-        """
-
-        def _count_messages(txn):
-            sql = """
-                SELECT COALESCE(COUNT(*), 0) FROM events
-                WHERE type = 'm.room.message'
-                AND stream_ordering > ?
-            """
-            txn.execute(sql, (self.stream_ordering_day_ago,))
-            count, = txn.fetchone()
-            return count
-
-        ret = yield self.runInteraction("count_messages", _count_messages)
-        return ret
-
-    @defer.inlineCallbacks
-    def count_daily_sent_messages(self):
-        def _count_messages(txn):
-            # This is good enough as if you have silly characters in your own
-            # hostname then that's your own fault.
-            like_clause = "%:" + self.hs.hostname
-
-            sql = """
-                SELECT COALESCE(COUNT(*), 0) FROM events
-                WHERE type = 'm.room.message'
-                    AND sender LIKE ?
-                AND stream_ordering > ?
-            """
-
-            txn.execute(sql, (like_clause, self.stream_ordering_day_ago))
-            count, = txn.fetchone()
-            return count
-
-        ret = yield self.runInteraction("count_daily_sent_messages", _count_messages)
-        return ret
-
-    @defer.inlineCallbacks
-    def count_daily_active_rooms(self):
-        def _count(txn):
-            sql = """
-                SELECT COALESCE(COUNT(DISTINCT room_id), 0) FROM events
-                WHERE type = 'm.room.message'
-                AND stream_ordering > ?
-            """
-            txn.execute(sql, (self.stream_ordering_day_ago,))
-            count, = txn.fetchone()
-            return count
-
-        ret = yield self.runInteraction("count_daily_active_rooms", _count)
-        return ret
-
-    def get_current_backfill_token(self):
-        """The current minimum token that backfilled events have reached"""
-        return -self._backfill_id_gen.get_current_token()
-
-    def get_current_events_token(self):
-        """The current maximum token that events have reached"""
-        return self._stream_id_gen.get_current_token()
-
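# A toy illustration of the sign convention used for backfill: backfilled
# events get negative stream orderings, and the backfill token is the negated
# current position, so it grows as backfill progresses. This assumes the id
# generator's current token is the most recently allocated (most negative) id.
stream_orderings = [-1, -2, -3]  # three backfilled events, allocated in order
current_token = min(stream_orderings)  # -3
backfill_token = -current_token  # 3, as returned by get_current_backfill_token
assert backfill_token == 3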
-    def get_all_new_forward_event_rows(self, last_id, current_id, limit):
-        if last_id == current_id:
-            return defer.succeed([])
-
-        def get_all_new_forward_event_rows(txn):
-            sql = (
-                "SELECT e.stream_ordering, e.event_id, e.room_id, e.type,"
-                " state_key, redacts, relates_to_id"
-                " FROM events AS e"
-                " LEFT JOIN redactions USING (event_id)"
-                " LEFT JOIN state_events USING (event_id)"
-                " LEFT JOIN event_relations USING (event_id)"
-                " WHERE ? < stream_ordering AND stream_ordering <= ?"
-                " ORDER BY stream_ordering ASC"
-                " LIMIT ?"
-            )
-            txn.execute(sql, (last_id, current_id, limit))
-            new_event_updates = txn.fetchall()
-
-            if len(new_event_updates) == limit:
-                upper_bound = new_event_updates[-1][0]
-            else:
-                upper_bound = current_id
-
-            sql = (
-                "SELECT event_stream_ordering, e.event_id, e.room_id, e.type,"
-                " state_key, redacts, relates_to_id"
-                " FROM events AS e"
-                " INNER JOIN ex_outlier_stream USING (event_id)"
-                " LEFT JOIN redactions USING (event_id)"
-                " LEFT JOIN state_events USING (event_id)"
-                " LEFT JOIN event_relations USING (event_id)"
-                " WHERE ? < event_stream_ordering"
-                " AND event_stream_ordering <= ?"
-                " ORDER BY event_stream_ordering DESC"
-            )
-            txn.execute(sql, (last_id, upper_bound))
-            new_event_updates.extend(txn)
-
-            return new_event_updates
-
-        return self.runInteraction(
-            "get_all_new_forward_event_rows", get_all_new_forward_event_rows
-        )
-
-    def get_all_new_backfill_event_rows(self, last_id, current_id, limit):
-        if last_id == current_id:
-            return defer.succeed([])
-
-        def get_all_new_backfill_event_rows(txn):
-            sql = (
-                "SELECT -e.stream_ordering, e.event_id, e.room_id, e.type,"
-                " state_key, redacts, relates_to_id"
-                " FROM events AS e"
-                " LEFT JOIN redactions USING (event_id)"
-                " LEFT JOIN state_events USING (event_id)"
-                " LEFT JOIN event_relations USING (event_id)"
-                " WHERE ? > stream_ordering AND stream_ordering >= ?"
-                " ORDER BY stream_ordering ASC"
-                " LIMIT ?"
-            )
-            txn.execute(sql, (-last_id, -current_id, limit))
-            new_event_updates = txn.fetchall()
-
-            if len(new_event_updates) == limit:
-                upper_bound = new_event_updates[-1][0]
-            else:
-                upper_bound = current_id
-
-            sql = (
-                "SELECT -event_stream_ordering, e.event_id, e.room_id, e.type,"
-                " state_key, redacts, relates_to_id"
-                " FROM events AS e"
-                " INNER JOIN ex_outlier_stream USING (event_id)"
-                " LEFT JOIN redactions USING (event_id)"
-                " LEFT JOIN state_events USING (event_id)"
-                " LEFT JOIN event_relations USING (event_id)"
-                " WHERE ? > event_stream_ordering"
-                " AND event_stream_ordering >= ?"
-                " ORDER BY event_stream_ordering DESC"
-            )
-            txn.execute(sql, (-last_id, -upper_bound))
-            new_event_updates.extend(txn.fetchall())
-
-            return new_event_updates
-
-        return self.runInteraction(
-            "get_all_new_backfill_event_rows", get_all_new_backfill_event_rows
-        )
-
-    @cached(num_args=5, max_entries=10)
-    def get_all_new_events(
-        self,
-        last_backfill_id,
-        last_forward_id,
-        current_backfill_id,
-        current_forward_id,
-        limit,
-    ):
-        """Get all the new events that have arrived at the server either as
-        new events or as backfilled events"""
-        have_backfill_events = last_backfill_id != current_backfill_id
-        have_forward_events = last_forward_id != current_forward_id
-
-        if not have_backfill_events and not have_forward_events:
-            return defer.succeed(AllNewEventsResult([], [], [], []))
-
-        def get_all_new_events_txn(txn):
-            sql = (
-                "SELECT e.stream_ordering, e.event_id, e.room_id, e.type,"
-                " state_key, redacts"
-                " FROM events AS e"
-                " LEFT JOIN redactions USING (event_id)"
-                " LEFT JOIN state_events USING (event_id)"
-                " WHERE ? < stream_ordering AND stream_ordering <= ?"
-                " ORDER BY stream_ordering ASC"
-                " LIMIT ?"
-            )
-            if have_forward_events:
-                txn.execute(sql, (last_forward_id, current_forward_id, limit))
-                new_forward_events = txn.fetchall()
-
-                if len(new_forward_events) == limit:
-                    upper_bound = new_forward_events[-1][0]
-                else:
-                    upper_bound = current_forward_id
-
-                sql = (
-                    "SELECT event_stream_ordering, event_id, state_group"
-                    " FROM ex_outlier_stream"
-                    " WHERE ? > event_stream_ordering"
-                    " AND event_stream_ordering >= ?"
-                    " ORDER BY event_stream_ordering DESC"
-                )
-                txn.execute(sql, (last_forward_id, upper_bound))
-                forward_ex_outliers = txn.fetchall()
-            else:
-                new_forward_events = []
-                forward_ex_outliers = []
-
-            sql = (
-                "SELECT -e.stream_ordering, e.event_id, e.room_id, e.type,"
-                " state_key, redacts"
-                " FROM events AS e"
-                " LEFT JOIN redactions USING (event_id)"
-                " LEFT JOIN state_events USING (event_id)"
-                " WHERE ? > stream_ordering AND stream_ordering >= ?"
-                " ORDER BY stream_ordering DESC"
-                " LIMIT ?"
-            )
-            if have_backfill_events:
-                txn.execute(sql, (-last_backfill_id, -current_backfill_id, limit))
-                new_backfill_events = txn.fetchall()
-
-                if len(new_backfill_events) == limit:
-                    upper_bound = new_backfill_events[-1][0]
-                else:
-                    upper_bound = current_backfill_id
-
-                sql = (
-                    "SELECT -event_stream_ordering, event_id, state_group"
-                    " FROM ex_outlier_stream"
-                    " WHERE ? > event_stream_ordering"
-                    " AND event_stream_ordering >= ?"
-                    " ORDER BY event_stream_ordering DESC"
-                )
-                txn.execute(sql, (-last_backfill_id, -upper_bound))
-                backward_ex_outliers = txn.fetchall()
-            else:
-                new_backfill_events = []
-                backward_ex_outliers = []
-
-            return AllNewEventsResult(
-                new_forward_events,
-                new_backfill_events,
-                forward_ex_outliers,
-                backward_ex_outliers,
-            )
-
-        return self.runInteraction("get_all_new_events", get_all_new_events_txn)
-
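For readers unpicking the negated parameters in the queries above, here is a minimal, self-contained sketch (plain Python over made-up rows, not Synapse code) of why storing backfilled events with negative stream orderings lets the same range comparison walk backwards through them:

```python
# Backfilled events get negative stream orderings; the query negates both the
# column and the bounds so that "new backfill rows" means "further back".
backfilled_rows = [(-1, "$a"), (-2, "$b"), (-3, "$c"), (-4, "$d")]  # (stream_ordering, event_id)

def new_backfill_rows(rows, last_id, upper_bound):
    # Mirrors: WHERE ? > event_stream_ordering AND event_stream_ordering >= ?
    # executed with (-last_id, -upper_bound), returning -event_stream_ordering,
    # ordered by event_stream_ordering DESC.
    matching = [(so, ev) for so, ev in rows if -last_id > so >= -upper_bound]
    matching.sort(key=lambda row: row[0], reverse=True)
    return [(-so, ev) for so, ev in matching]

print(new_backfill_rows(backfilled_rows, last_id=1, upper_bound=4))
# [(2, '$b'), (3, '$c'), (4, '$d')]
```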
-    def purge_history(self, room_id, token, delete_local_events):
-        """Deletes room history before a certain point
-
-        Args:
-            room_id (str):
-
-            token (str): A topological token to delete events before
-
-            delete_local_events (bool):
-                if True, we will delete local events as well as remote ones
-                (instead of just marking them as outliers and deleting their
-                state groups).
-        """
-
-        return self.runInteraction(
-            "purge_history",
-            self._purge_history_txn,
-            room_id,
-            token,
-            delete_local_events,
-        )
-
-    def _purge_history_txn(self, txn, room_id, token_str, delete_local_events):
-        token = RoomStreamToken.parse(token_str)
-
-        # Tables that should be pruned:
-        #     event_auth
-        #     event_backward_extremities
-        #     event_edges
-        #     event_forward_extremities
-        #     event_json
-        #     event_push_actions
-        #     event_reference_hashes
-        #     event_search
-        #     event_to_state_groups
-        #     events
-        #     rejections
-        #     room_depth
-        #     state_groups
-        #     state_groups_state
-
-        # we will build a temporary table listing the events so that we don't
-        # have to keep shovelling the list back and forth across the
-        # connection. Annoyingly the python sqlite driver commits the
-        # transaction on CREATE, so let's do this first.
-        #
-        # furthermore, we might already have the table from a previous (failed)
-        # purge attempt, so let's drop the table first.
-
-        txn.execute("DROP TABLE IF EXISTS events_to_purge")
-
-        txn.execute(
-            "CREATE TEMPORARY TABLE events_to_purge ("
-            "    event_id TEXT NOT NULL,"
-            "    should_delete BOOLEAN NOT NULL"
-            ")"
-        )
-
-        # First ensure that we're not about to delete all the forward extremities
-        txn.execute(
-            "SELECT e.event_id, e.depth FROM events as e "
-            "INNER JOIN event_forward_extremities as f "
-            "ON e.event_id = f.event_id "
-            "AND e.room_id = f.room_id "
-            "WHERE f.room_id = ?",
-            (room_id,),
-        )
-        rows = txn.fetchall()
-        max_depth = max(row[1] for row in rows)
-
-        if max_depth < token.topological:
-            # We need to ensure we don't delete all the events from the database
-            # otherwise we wouldn't be able to send any events (due to not
-            # having any backwards extremities)
-            raise SynapseError(
-                400, "topological_ordering is greater than forward extremeties"
-            )
-
-        logger.info("[purge] looking for events to delete")
-
-        should_delete_expr = "state_key IS NULL"
-        should_delete_params = ()
-        if not delete_local_events:
-            should_delete_expr += " AND event_id NOT LIKE ?"
-
-            # We include the parameter twice since we use the expression twice
-            should_delete_params += ("%:" + self.hs.hostname, "%:" + self.hs.hostname)
-
-        should_delete_params += (room_id, token.topological)
-
-        # Note that we insert events that are outliers and aren't going to be
-        # deleted, as nothing will happen to them.
-        txn.execute(
-            "INSERT INTO events_to_purge"
-            " SELECT event_id, %s"
-            " FROM events AS e LEFT JOIN state_events USING (event_id)"
-            " WHERE (NOT outlier OR (%s)) AND e.room_id = ? AND topological_ordering < ?"
-            % (should_delete_expr, should_delete_expr),
-            should_delete_params,
-        )
-
-        # We create the indices *after* insertion as that's a lot faster.
-
-        # create an index on should_delete because later we'll be looking for
-        # the should_delete / shouldn't_delete subsets
-        txn.execute(
-            "CREATE INDEX events_to_purge_should_delete"
-            " ON events_to_purge(should_delete)"
-        )
-
-        # We do joins against events_to_purge for e.g. calculating state
-        # groups to purge, etc., so let's make an index.
-        txn.execute("CREATE INDEX events_to_purge_id" " ON events_to_purge(event_id)")
-
-        txn.execute("SELECT event_id, should_delete FROM events_to_purge")
-        event_rows = txn.fetchall()
-        logger.info(
-            "[purge] found %i events before cutoff, of which %i can be deleted",
-            len(event_rows),
-            sum(1 for e in event_rows if e[1]),
-        )
-
-        logger.info("[purge] Finding new backward extremities")
-
-        # We calculate the new entries for the backward extremities by finding
-        # events to be purged that are pointed to by events we're not going to
-        # purge.
-        txn.execute(
-            "SELECT DISTINCT e.event_id FROM events_to_purge AS e"
-            " INNER JOIN event_edges AS ed ON e.event_id = ed.prev_event_id"
-            " LEFT JOIN events_to_purge AS ep2 ON ed.event_id = ep2.event_id"
-            " WHERE ep2.event_id IS NULL"
-        )
-        new_backwards_extrems = txn.fetchall()
-
-        logger.info("[purge] replacing backward extremities: %r", new_backwards_extrems)
-
-        txn.execute(
-            "DELETE FROM event_backward_extremities WHERE room_id = ?", (room_id,)
-        )
-
-        # Update backward extremities
-        txn.executemany(
-            "INSERT INTO event_backward_extremities (room_id, event_id)"
-            " VALUES (?, ?)",
-            [(room_id, event_id) for event_id, in new_backwards_extrems],
-        )
-
-        logger.info("[purge] finding redundant state groups")
-
-        # Get all state groups that are referenced by events that are to be
-        # deleted. We then go and check if they are referenced by other events
-        # or state groups, and if not we delete them.
-        txn.execute(
-            """
-            SELECT DISTINCT state_group FROM events_to_purge
-            INNER JOIN event_to_state_groups USING (event_id)
-        """
-        )
-
-        referenced_state_groups = set(sg for sg, in txn)
-        logger.info(
-            "[purge] found %i referenced state groups", len(referenced_state_groups)
-        )
-
-        logger.info("[purge] finding state groups that can be deleted")
-
-        (
-            state_groups_to_delete,
-            remaining_state_groups,
-        ) = self._find_unreferenced_groups_during_purge(txn, referenced_state_groups)
-
-        logger.info(
-            "[purge] found %i state groups to delete", len(state_groups_to_delete)
-        )
-
-        logger.info(
-            "[purge] de-delta-ing %i remaining state groups",
-            len(remaining_state_groups),
-        )
-
-        # Now we turn the state groups that reference to-be-deleted state
-        # groups into non-delta versions.
-        for sg in remaining_state_groups:
-            logger.info("[purge] de-delta-ing remaining state group %s", sg)
-            curr_state = self._get_state_groups_from_groups_txn(txn, [sg])
-            curr_state = curr_state[sg]
-
-            self._simple_delete_txn(
-                txn, table="state_groups_state", keyvalues={"state_group": sg}
-            )
-
-            self._simple_delete_txn(
-                txn, table="state_group_edges", keyvalues={"state_group": sg}
-            )
-
-            self._simple_insert_many_txn(
-                txn,
-                table="state_groups_state",
-                values=[
-                    {
-                        "state_group": sg,
-                        "room_id": room_id,
-                        "type": key[0],
-                        "state_key": key[1],
-                        "event_id": state_id,
-                    }
-                    for key, state_id in iteritems(curr_state)
-                ],
-            )
-
-        logger.info("[purge] removing redundant state groups")
-        txn.executemany(
-            "DELETE FROM state_groups_state WHERE state_group = ?",
-            ((sg,) for sg in state_groups_to_delete),
-        )
-        txn.executemany(
-            "DELETE FROM state_groups WHERE id = ?",
-            ((sg,) for sg in state_groups_to_delete),
-        )
-
-        logger.info("[purge] removing events from event_to_state_groups")
-        txn.execute(
-            "DELETE FROM event_to_state_groups "
-            "WHERE event_id IN (SELECT event_id from events_to_purge)"
-        )
-        for event_id, _ in event_rows:
-            txn.call_after(self._get_state_group_for_event.invalidate, (event_id,))
-
-        # Delete all remote non-state events
-        for table in (
-            "events",
-            "event_json",
-            "event_auth",
-            "event_edges",
-            "event_forward_extremities",
-            "event_reference_hashes",
-            "event_search",
-            "rejections",
-        ):
-            logger.info("[purge] removing events from %s", table)
-
-            txn.execute(
-                "DELETE FROM %s WHERE event_id IN ("
-                "    SELECT event_id FROM events_to_purge WHERE should_delete"
-                ")" % (table,)
-            )
-
-        # event_push_actions lacks an index on event_id, and has one on
-        # (room_id, event_id) instead.
-        for table in ("event_push_actions",):
-            logger.info("[purge] removing events from %s", table)
-
-            txn.execute(
-                "DELETE FROM %s WHERE room_id = ? AND event_id IN ("
-                "    SELECT event_id FROM events_to_purge WHERE should_delete"
-                ")" % (table,),
-                (room_id,),
-            )
-
-        # Mark all state and own events as outliers
-        logger.info("[purge] marking remaining events as outliers")
-        txn.execute(
-            "UPDATE events SET outlier = ?"
-            " WHERE event_id IN ("
-            "    SELECT event_id FROM events_to_purge "
-            "    WHERE NOT should_delete"
-            ")",
-            (True,),
-        )
-
-        # synapse tries to take out an exclusive lock on room_depth whenever it
-        # persists events (because of the upsert), and once we run this update, we
-        # will block that for the rest of our transaction.
-        #
-        # So, let's stick it at the end so that we don't block event
-        # persistence.
-        #
-        # We do this by calculating the minimum depth of the backwards
-        # extremities. However, the events in event_backward_extremities
-        # are ones we don't have yet, so we need to look at the events that
-        # point to them via the event_edges table.
-        txn.execute(
-            """
-            SELECT COALESCE(MIN(depth), 0)
-            FROM event_backward_extremities AS eb
-            INNER JOIN event_edges AS eg ON eg.prev_event_id = eb.event_id
-            INNER JOIN events AS e ON e.event_id = eg.event_id
-            WHERE eb.room_id = ?
-        """,
-            (room_id,),
-        )
-        min_depth, = txn.fetchone()
-
-        logger.info("[purge] updating room_depth to %d", min_depth)
-
-        txn.execute(
-            "UPDATE room_depth SET min_depth = ? WHERE room_id = ?",
-            (min_depth, room_id),
-        )
-
-        # finally, drop the temp table. this will commit the txn in sqlite,
-        # so make sure to keep this actually last.
-        txn.execute("DROP TABLE events_to_purge")
-
-        logger.info("[purge] done")
-
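As a standalone illustration of the backward-extremity recalculation in `_purge_history_txn` above (plain Python on in-memory data, not the real implementation): a purged event becomes a new backward extremity when an event we are keeping still lists it as a prev_event.

```python
def new_backward_extremities(events_to_purge, event_edges):
    """events_to_purge: event_ids being purged.
    event_edges: (event_id, prev_event_id) pairs, as in the event_edges table."""
    purged = set(events_to_purge)
    return {
        prev
        for event_id, prev in event_edges
        if prev in purged and event_id not in purged
    }

edges = [("$keep1", "$old1"), ("$old2", "$old1"), ("$keep2", "$old2")]
print(new_backward_extremities({"$old1", "$old2"}, edges))
# both purged events are still pointed at by surviving events
```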
-    def _find_unreferenced_groups_during_purge(self, txn, state_groups):
-        """Used when purging history to figure out which state groups can be
-        deleted and which need to be de-delta'ed (because their prev group is
-        scheduled for deletion).
-
-        Args:
-            txn
-            state_groups (set[int]): Set of state groups referenced by events
-                that are going to be deleted.
-
-        Returns:
-            tuple[set[int], set[int]]: The set of state groups that can be
-            deleted and the set of state groups that need to be de-delta'ed
-        """
-        # Graph of state group -> previous group
-        graph = {}
-
-        # Set of state groups that we have found to be referenced by events
-        referenced_groups = set()
-
-        # Set of state groups we've already seen
-        state_groups_seen = set(state_groups)
-
-        # Set of state groups to handle next.
-        next_to_search = set(state_groups)
-        while next_to_search:
-            # We bound the number of groups we look up at once, to stop the
-            # SQL query getting too big
-            if len(next_to_search) < 100:
-                current_search = next_to_search
-                next_to_search = set()
-            else:
-                current_search = set(itertools.islice(next_to_search, 100))
-                next_to_search -= current_search
-
-            # Check if state groups are referenced
-            sql = """
-                SELECT DISTINCT state_group FROM event_to_state_groups
-                LEFT JOIN events_to_purge AS ep USING (event_id)
-                WHERE state_group IN (%s) AND ep.event_id IS NULL
-            """ % (
-                ",".join("?" for _ in current_search),
-            )
-            txn.execute(sql, list(current_search))
-
-            referenced = set(sg for sg, in txn)
-            referenced_groups |= referenced
-
-            # We don't continue iterating up the state group graphs for state
-            # groups that are referenced.
-            current_search -= referenced
-
-            rows = self._simple_select_many_txn(
-                txn,
-                table="state_group_edges",
-                column="prev_state_group",
-                iterable=current_search,
-                keyvalues={},
-                retcols=("prev_state_group", "state_group"),
-            )
-
-            prevs = set(row["state_group"] for row in rows)
-            # We don't bother re-handling groups we've already seen
-            prevs -= state_groups_seen
-            next_to_search |= prevs
-            state_groups_seen |= prevs
-
-            for row in rows:
-                # Note: Each state group can have at most one prev group
-                graph[row["state_group"]] = row["prev_state_group"]
-
-        to_delete = state_groups_seen - referenced_groups
-
-        to_dedelta = set()
-        for sg in referenced_groups:
-            prev_sg = graph.get(sg)
-            if prev_sg and prev_sg in to_delete:
-                to_dedelta.add(sg)
-
-        return to_delete, to_dedelta
-
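A pure-Python sketch (illustrative only, in-memory data) of what `_find_unreferenced_groups_during_purge` computes: starting from the state groups of the purged events, follow the `state_group_edges` graph to every group built on top of them, stop at any group still referenced by surviving events, and report which groups can be dropped and which must be de-delta'ed because their prev group is going away.

```python
def find_unreferenced(start_groups, edges, still_referenced):
    """edges: (prev_state_group, state_group) pairs from state_group_edges.
    still_referenced: groups referenced by events that are NOT being purged."""
    children = {}                                # prev group -> groups delta'd on top of it
    prev_of = {}                                 # group -> its (single) prev group
    for prev, sg in edges:
        children.setdefault(prev, []).append(sg)
        prev_of[sg] = prev

    referenced, seen, frontier = set(), set(start_groups), set(start_groups)
    while frontier:
        current = frontier
        frontier = set()
        referenced |= current & still_referenced
        current = current - still_referenced     # don't expand referenced groups
        for sg in current:
            for child in children.get(sg, ()):
                if child not in seen:
                    seen.add(child)
                    frontier.add(child)

    to_delete = seen - referenced
    to_dedelta = {sg for sg in referenced if prev_of.get(sg) in to_delete}
    return to_delete, to_dedelta

print(find_unreferenced({1}, edges=[(1, 2), (2, 3)], still_referenced={3}))
# ({1, 2}, {3}): groups 1 and 2 can go, group 3 must be de-delta'ed first
```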
-    def purge_room(self, room_id):
-        """Deletes all record of a room
-
-        Args:
-            room_id (str):
-        """
-
-        return self.runInteraction("purge_room", self._purge_room_txn, room_id)
-
-    def _purge_room_txn(self, txn, room_id):
-        # first we have to delete the state groups' state
-        logger.info("[purge] removing %s from state_groups_state", room_id)
-
-        txn.execute(
-            """
-            DELETE FROM state_groups_state WHERE state_group IN (
-              SELECT state_group FROM events JOIN event_to_state_groups USING(event_id)
-              WHERE events.room_id=?
-            )
-            """,
-            (room_id,),
-        )
-
-        # ... and the state group edges
-        logger.info("[purge] removing %s from state_group_edges", room_id)
-
-        txn.execute(
-            """
-            DELETE FROM state_group_edges WHERE state_group IN (
-              SELECT state_group FROM events JOIN event_to_state_groups USING(event_id)
-              WHERE events.room_id=?
-            )
-            """,
-            (room_id,),
-        )
-
-        # ... and the state groups
-        logger.info("[purge] removing %s from state_groups", room_id)
-
-        txn.execute(
-            """
-            DELETE FROM state_groups WHERE id IN (
-              SELECT state_group FROM events JOIN event_to_state_groups USING(event_id)
-              WHERE events.room_id=?
-            )
-            """,
-            (room_id,),
-        )
-
-        # and then tables which lack an index on room_id but have one on event_id
-        for table in (
-            "event_auth",
-            "event_edges",
-            "event_push_actions_staging",
-            "event_reference_hashes",
-            "event_relations",
-            "event_to_state_groups",
-            "redactions",
-            "rejections",
-            "state_events",
-        ):
-            logger.info("[purge] removing %s from %s", room_id, table)
-
-            txn.execute(
-                """
-                DELETE FROM %s WHERE event_id IN (
-                  SELECT event_id FROM events WHERE room_id=?
-                )
-                """
-                % (table,),
-                (room_id,),
-            )
-
-        # and finally, the tables with an index on room_id (or no useful index)
-        for table in (
-            "current_state_events",
-            "event_backward_extremities",
-            "event_forward_extremities",
-            "event_json",
-            "event_push_actions",
-            "event_search",
-            "events",
-            "group_rooms",
-            "public_room_list_stream",
-            "receipts_graph",
-            "receipts_linearized",
-            "room_aliases",
-            "room_depth",
-            "room_memberships",
-            "room_stats_state",
-            "room_stats_current",
-            "room_stats_historical",
-            "room_stats_earliest_token",
-            "rooms",
-            "stream_ordering_to_exterm",
-            "topics",
-            "users_in_public_rooms",
-            "users_who_share_private_rooms",
-            # no useful index, but let's clear them anyway
-            "appservice_room_list",
-            "e2e_room_keys",
-            "event_push_summary",
-            "pusher_throttle",
-            "group_summary_rooms",
-            "local_invites",
-            "room_account_data",
-            "room_tags",
-        ):
-            logger.info("[purge] removing %s from %s", room_id, table)
-            txn.execute("DELETE FROM %s WHERE room_id=?" % (table,), (room_id,))
-
-        # Other tables we do NOT need to clear out:
-        #
-        #  - blocked_rooms
-        #    This is important, to make sure that we don't accidentally rejoin a blocked
-        #    room after it was purged
-        #
-        #  - user_directory
-        #    This has a room_id column, but it is unused
-        #
-
-        # Other tables that we might want to consider clearing out include:
-        #
-        #  - event_reports
-        #       Given that these are intended for abuse management my initial
-        #       inclination is to leave them in place.
-        #
-        #  - current_state_delta_stream
-        #  - ex_outlier_stream
-        #  - room_tags_revisions
-        #       The problem with these is that they are largeish and there is no room_id
-        #       index on them. In any case we should be clearing out 'stream' tables
-        #       periodically anyway (#5888)
-
-        # TODO: we could probably usefully do a bunch of cache invalidation here
-
-        logger.info("[purge] done")
-
-    @defer.inlineCallbacks
-    def is_event_after(self, event_id1, event_id2):
-        """Returns True if event_id1 is after event_id2 in the stream
-        """
-        to_1, so_1 = yield self._get_event_ordering(event_id1)
-        to_2, so_2 = yield self._get_event_ordering(event_id2)
-        return (to_1, so_1) > (to_2, so_2)
-
-    @cachedInlineCallbacks(max_entries=5000)
-    def _get_event_ordering(self, event_id):
-        res = yield self._simple_select_one(
-            table="events",
-            retcols=["topological_ordering", "stream_ordering"],
-            keyvalues={"event_id": event_id},
-            allow_none=True,
-        )
-
-        if not res:
-            raise SynapseError(404, "Could not find event %s" % (event_id,))
-
-        return (int(res["topological_ordering"]), int(res["stream_ordering"]))
-
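The tuple comparison above is the whole trick: events are ordered by `(topological_ordering, stream_ordering)`. A tiny standalone example with made-up event ids:

```python
orderings = {"$e1": (10, 105), "$e2": (10, 99), "$e3": (9, 200)}

def is_after(event_id1, event_id2):
    # Lexicographic tuple comparison: depth first, then stream position.
    return orderings[event_id1] > orderings[event_id2]

assert is_after("$e1", "$e2")   # same depth, later stream ordering
assert is_after("$e2", "$e3")   # greater topological ordering wins outright
```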
-    def get_all_updated_current_state_deltas(self, from_token, to_token, limit):
-        def get_all_updated_current_state_deltas_txn(txn):
-            sql = """
-                SELECT stream_id, room_id, type, state_key, event_id
-                FROM current_state_delta_stream
-                WHERE ? < stream_id AND stream_id <= ?
-                ORDER BY stream_id ASC LIMIT ?
-            """
-            txn.execute(sql, (from_token, to_token, limit))
-            return txn.fetchall()
-
-        return self.runInteraction(
-            "get_all_updated_current_state_deltas",
-            get_all_updated_current_state_deltas_txn,
-        )
-
-
-AllNewEventsResult = namedtuple(
-    "AllNewEventsResult",
-    [
-        "new_forward_events",
-        "new_backfill_events",
-        "forward_ex_outliers",
-        "backward_ex_outliers",
-    ],
-)
diff --git a/synapse/storage/keys.py b/synapse/storage/keys.py
index e72f89e446..4769b21529 100644
--- a/synapse/storage/keys.py
+++ b/synapse/storage/keys.py
@@ -14,208 +14,14 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
-import itertools
 import logging
 
-import six
-
 import attr
-from signedjson.key import decode_verify_key_bytes
-
-from synapse.util import batch_iter
-from synapse.util.caches.descriptors import cached, cachedList
-
-from ._base import SQLBaseStore
 
 logger = logging.getLogger(__name__)
 
-# py2 sqlite has buffer hardcoded as the only binary type, so we must use it,
-# despite being deprecated and removed in favor of memoryview
-if six.PY2:
-    db_binary_type = six.moves.builtins.buffer
-else:
-    db_binary_type = memoryview
-
 
 @attr.s(slots=True, frozen=True)
 class FetchKeyResult(object):
     verify_key = attr.ib()  # VerifyKey: the key itself
     valid_until_ts = attr.ib()  # int: how long we can use this key for
-
-
-class KeyStore(SQLBaseStore):
-    """Persistence for signature verification keys
-    """
-
-    @cached()
-    def _get_server_verify_key(self, server_name_and_key_id):
-        raise NotImplementedError()
-
-    @cachedList(
-        cached_method_name="_get_server_verify_key", list_name="server_name_and_key_ids"
-    )
-    def get_server_verify_keys(self, server_name_and_key_ids):
-        """
-        Args:
-            server_name_and_key_ids (iterable[Tuple[str, str]]):
-                iterable of (server_name, key-id) tuples to fetch keys for
-
-        Returns:
-            Deferred: resolves to dict[Tuple[str, str], FetchKeyResult|None]:
-                map from (server_name, key_id) -> FetchKeyResult, or None if the key is
-                unknown
-        """
-        keys = {}
-
-        def _get_keys(txn, batch):
-            """Processes a batch of keys to fetch, and adds the result to `keys`."""
-
-            # batch_iter always returns tuples so it's safe to do len(batch)
-            sql = (
-                "SELECT server_name, key_id, verify_key, ts_valid_until_ms "
-                "FROM server_signature_keys WHERE 1=0"
-            ) + " OR (server_name=? AND key_id=?)" * len(batch)
-
-            txn.execute(sql, tuple(itertools.chain.from_iterable(batch)))
-
-            for row in txn:
-                server_name, key_id, key_bytes, ts_valid_until_ms = row
-
-                if ts_valid_until_ms is None:
-                    # Old keys may be stored with a ts_valid_until_ms of null,
-                    # in which case we treat this as if it was set to `0`, i.e.
-                    # it won't match key requests that define a minimum
-                    # `ts_valid_until_ms`.
-                    ts_valid_until_ms = 0
-
-                res = FetchKeyResult(
-                    verify_key=decode_verify_key_bytes(key_id, bytes(key_bytes)),
-                    valid_until_ts=ts_valid_until_ms,
-                )
-                keys[(server_name, key_id)] = res
-
-        def _txn(txn):
-            for batch in batch_iter(server_name_and_key_ids, 50):
-                _get_keys(txn, batch)
-            return keys
-
-        return self.runInteraction("get_server_verify_keys", _txn)
-
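To make the query construction above concrete, here is a standalone sketch (the server names and key ids are made up) of how a batch of `(server_name, key_id)` pairs becomes one SQL statement plus a flat parameter tuple:

```python
import itertools

def build_key_query(batch):
    # One "(server_name=? AND key_id=?)" clause per requested key, with the
    # parameters flattened in the same order.
    sql = (
        "SELECT server_name, key_id, verify_key, ts_valid_until_ms "
        "FROM server_signature_keys WHERE 1=0"
    ) + " OR (server_name=? AND key_id=?)" * len(batch)
    params = tuple(itertools.chain.from_iterable(batch))
    return sql, params

sql, params = build_key_query(
    [("matrix.org", "ed25519:auto"), ("example.com", "ed25519:key1")]
)
print(params)  # ('matrix.org', 'ed25519:auto', 'example.com', 'ed25519:key1')
```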
-    def store_server_verify_keys(self, from_server, ts_added_ms, verify_keys):
-        """Stores NACL verification keys for remote servers.
-        Args:
-            from_server (str): Where the verification keys were looked up
-            ts_added_ms (int): The time to record that the key was added
-            verify_keys (iterable[tuple[str, str, FetchKeyResult]]):
-                keys to be stored. Each entry is a triplet of
-                (server_name, key_id, key).
-        """
-        key_values = []
-        value_values = []
-        invalidations = []
-        for server_name, key_id, fetch_result in verify_keys:
-            key_values.append((server_name, key_id))
-            value_values.append(
-                (
-                    from_server,
-                    ts_added_ms,
-                    fetch_result.valid_until_ts,
-                    db_binary_type(fetch_result.verify_key.encode()),
-                )
-            )
-            # invalidate takes a tuple corresponding to the params of
-            # _get_server_verify_key. _get_server_verify_key only takes one
-            # param, which is itself the 2-tuple (server_name, key_id).
-            invalidations.append((server_name, key_id))
-
-        def _invalidate(res):
-            f = self._get_server_verify_key.invalidate
-            for i in invalidations:
-                f((i,))
-            return res
-
-        return self.runInteraction(
-            "store_server_verify_keys",
-            self._simple_upsert_many_txn,
-            table="server_signature_keys",
-            key_names=("server_name", "key_id"),
-            key_values=key_values,
-            value_names=(
-                "from_server",
-                "ts_added_ms",
-                "ts_valid_until_ms",
-                "verify_key",
-            ),
-            value_values=value_values,
-        ).addCallback(_invalidate)
-
-    def store_server_keys_json(
-        self, server_name, key_id, from_server, ts_now_ms, ts_expires_ms, key_json_bytes
-    ):
-        """Stores the JSON bytes for a set of keys from a server
-        The JSON should be signed by the originating server, the intermediate
-        server, and by this server. Updates the value for the
-        (server_name, key_id, from_server) triplet if one already existed.
-        Args:
-            server_name (str): The name of the server.
-            key_id (str): The identifier of the key this JSON is for.
-            from_server (str): The server this JSON was fetched from.
-            ts_now_ms (int): The time now in milliseconds.
-            ts_expires_ms (int): The time when this JSON stops being valid.
-            key_json_bytes (bytes): The encoded JSON.
-        """
-        return self._simple_upsert(
-            table="server_keys_json",
-            keyvalues={
-                "server_name": server_name,
-                "key_id": key_id,
-                "from_server": from_server,
-            },
-            values={
-                "server_name": server_name,
-                "key_id": key_id,
-                "from_server": from_server,
-                "ts_added_ms": ts_now_ms,
-                "ts_valid_until_ms": ts_expires_ms,
-                "key_json": db_binary_type(key_json_bytes),
-            },
-            desc="store_server_keys_json",
-        )
-
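The `_simple_upsert` call above relies on Synapse's portable upsert helper. Purely as a hedged illustration of the "update if present, otherwise insert" semantics it provides, the same idea can be written directly against SQLite 3.24+ (the table schema here is simplified and assumed, not the real one):

```python
import sqlite3

conn = sqlite3.connect(":memory:")
conn.execute(
    """CREATE TABLE server_keys_json (
        server_name TEXT, key_id TEXT, from_server TEXT,
        ts_added_ms INTEGER, ts_valid_until_ms INTEGER, key_json BLOB,
        PRIMARY KEY (server_name, key_id, from_server)
    )"""
)
# Running this twice leaves a single row carrying the newer timestamps.
for ts in (1000, 2000):
    conn.execute(
        """INSERT INTO server_keys_json
           (server_name, key_id, from_server, ts_added_ms, ts_valid_until_ms, key_json)
           VALUES (?, ?, ?, ?, ?, ?)
           ON CONFLICT (server_name, key_id, from_server)
           DO UPDATE SET ts_added_ms = excluded.ts_added_ms,
                         ts_valid_until_ms = excluded.ts_valid_until_ms,
                         key_json = excluded.key_json""",
        ("example.com", "ed25519:key1", "matrix.org", ts, ts + 86400000, b"{}"),
    )
print(conn.execute("SELECT ts_added_ms FROM server_keys_json").fetchall())  # [(2000,)]
```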
-    def get_server_keys_json(self, server_keys):
-        """Retrive the key json for a list of server_keys and key ids.
-        If no keys are found for a given server, key_id and source then
-        that server, key_id, and source triplet entry will be an empty list.
-        The JSON is returned as a byte array so that it can be efficiently
-        used in an HTTP response.
-        Args:
-            server_keys (list): List of (server_name, key_id, source) triplets.
-        Returns:
-            Deferred[dict[Tuple[str, str, str|None], list[dict]]]:
-                Dict mapping (server_name, key_id, source) triplets to lists of dicts
-        """
-
-        def _get_server_keys_json_txn(txn):
-            results = {}
-            for server_name, key_id, from_server in server_keys:
-                keyvalues = {"server_name": server_name}
-                if key_id is not None:
-                    keyvalues["key_id"] = key_id
-                if from_server is not None:
-                    keyvalues["from_server"] = from_server
-                rows = self._simple_select_list_txn(
-                    txn,
-                    "server_keys_json",
-                    keyvalues=keyvalues,
-                    retcols=(
-                        "key_id",
-                        "from_server",
-                        "ts_added_ms",
-                        "ts_valid_until_ms",
-                        "key_json",
-                    ),
-                )
-                results[(server_name, key_id, from_server)] = rows
-            return results
-
-        return self.runInteraction("get_server_keys_json", _get_server_keys_json_txn)
diff --git a/synapse/storage/persist_events.py b/synapse/storage/persist_events.py
new file mode 100644
index 0000000000..f159400a87
--- /dev/null
+++ b/synapse/storage/persist_events.py
@@ -0,0 +1,794 @@
+# -*- coding: utf-8 -*-
+# Copyright 2014-2016 OpenMarket Ltd
+# Copyright 2018-2019 New Vector Ltd
+# Copyright 2019 The Matrix.org Foundation C.I.C.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import itertools
+import logging
+from collections import deque, namedtuple
+from typing import Iterable, List, Optional, Set, Tuple
+
+from six import iteritems
+from six.moves import range
+
+from prometheus_client import Counter, Histogram
+
+from twisted.internet import defer
+
+from synapse.api.constants import EventTypes, Membership
+from synapse.events import FrozenEvent
+from synapse.events.snapshot import EventContext
+from synapse.logging.context import PreserveLoggingContext, make_deferred_yieldable
+from synapse.metrics.background_process_metrics import run_as_background_process
+from synapse.state import StateResolutionStore
+from synapse.storage.data_stores import DataStores
+from synapse.storage.data_stores.main.events import DeltaState
+from synapse.types import StateMap
+from synapse.util.async_helpers import ObservableDeferred
+from synapse.util.metrics import Measure
+
+logger = logging.getLogger(__name__)
+
+# The number of times we are recalculating the current state
+state_delta_counter = Counter("synapse_storage_events_state_delta", "")
+
+# The number of times we are recalculating state when there is only a
+# single forward extremity
+state_delta_single_event_counter = Counter(
+    "synapse_storage_events_state_delta_single_event", ""
+)
+
+# The number of times we are recalculating state when we could have reasonably
+# calculated the delta when we calculated the state for an event we were
+# persisting.
+state_delta_reuse_delta_counter = Counter(
+    "synapse_storage_events_state_delta_reuse_delta", ""
+)
+
+# The number of forward extremities for each new event.
+forward_extremities_counter = Histogram(
+    "synapse_storage_events_forward_extremities_persisted",
+    "Number of forward extremities for each new event",
+    buckets=(1, 2, 3, 5, 7, 10, 15, 20, 50, 100, 200, 500, "+Inf"),
+)
+
+# The number of stale forward extremities for each new event. Stale extremities
+# are those that were in the previous set of extremities as well as in the new set.
+stale_forward_extremities_counter = Histogram(
+    "synapse_storage_events_stale_forward_extremities_persisted",
+    "Number of unchanged forward extremities for each new event",
+    buckets=(0, 1, 2, 3, 5, 7, 10, 15, 20, 50, 100, 200, 500, "+Inf"),
+)
+
+
+class _EventPeristenceQueue(object):
+    """Queues up events so that they can be persisted in bulk with only one
+    concurrent transaction per room.
+    """
+
+    _EventPersistQueueItem = namedtuple(
+        "_EventPersistQueueItem", ("events_and_contexts", "backfilled", "deferred")
+    )
+
+    def __init__(self):
+        self._event_persist_queues = {}
+        self._currently_persisting_rooms = set()
+
+    def add_to_queue(self, room_id, events_and_contexts, backfilled):
+        """Add events to the queue, with the given persist_event options.
+
+        NB: due to the normal usage pattern of this method, it does *not*
+        follow the synapse logcontext rules, and leaves the logcontext in
+        place whether or not the returned deferred is ready.
+
+        Args:
+            room_id (str):
+            events_and_contexts (list[(EventBase, EventContext)]):
+            backfilled (bool):
+
+        Returns:
+            defer.Deferred: a deferred which will resolve once the events are
+                persisted. Runs its callbacks *without* a logcontext.
+        """
+        queue = self._event_persist_queues.setdefault(room_id, deque())
+        if queue:
+            # if the last item in the queue has the same `backfilled` setting,
+            # we can just add these new events to that item.
+            end_item = queue[-1]
+            if end_item.backfilled == backfilled:
+                end_item.events_and_contexts.extend(events_and_contexts)
+                return end_item.deferred.observe()
+
+        deferred = ObservableDeferred(defer.Deferred(), consumeErrors=True)
+
+        queue.append(
+            self._EventPersistQueueItem(
+                events_and_contexts=events_and_contexts,
+                backfilled=backfilled,
+                deferred=deferred,
+            )
+        )
+
+        return deferred.observe()
+
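To illustrate the coalescing behaviour above outside of Twisted (a sketch only; a plain `concurrent.futures.Future` stands in for the `ObservableDeferred`): consecutive additions with the same `backfilled` flag are merged into the tail item, so every caller waits on the same batch.

```python
from collections import deque, namedtuple
from concurrent.futures import Future

QueueItem = namedtuple("QueueItem", ("events", "backfilled", "future"))
queues = {}  # room_id -> deque of pending QueueItems

def add_to_queue(room_id, events, backfilled):
    queue = queues.setdefault(room_id, deque())
    if queue and queue[-1].backfilled == backfilled:
        queue[-1].events.extend(events)   # merge into the tail batch
        return queue[-1].future           # ... and share its completion future
    item = QueueItem(list(events), backfilled, Future())
    queue.append(item)
    return item.future

f1 = add_to_queue("!room:example.org", ["$e1"], backfilled=False)
f2 = add_to_queue("!room:example.org", ["$e2"], backfilled=False)
assert f1 is f2   # both callers will be woken by the same persistence batch
```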
+    def handle_queue(self, room_id, per_item_callback):
+        """Attempts to handle the queue for a room if not already being handled.
+
+        The given callback will be invoked for each item in the queue, which is
+        of type _EventPersistQueueItem. The per_item_callback will continuously
+        be called with new items until the queue becomes empty. The return
+        value of the callback will be given to the deferreds waiting on the
+        item; exceptions will be passed to the deferreds as well.
+
+        This function should therefore be called whenever anything is added
+        to the queue.
+
+        If another callback is currently handling the queue then it will not be
+        invoked.
+        """
+
+        if room_id in self._currently_persisting_rooms:
+            return
+
+        self._currently_persisting_rooms.add(room_id)
+
+        async def handle_queue_loop():
+            try:
+                queue = self._get_drainining_queue(room_id)
+                for item in queue:
+                    try:
+                        ret = await per_item_callback(item)
+                    except Exception:
+                        with PreserveLoggingContext():
+                            item.deferred.errback()
+                    else:
+                        with PreserveLoggingContext():
+                            item.deferred.callback(ret)
+            finally:
+                queue = self._event_persist_queues.pop(room_id, None)
+                if queue:
+                    self._event_persist_queues[room_id] = queue
+                self._currently_persisting_rooms.discard(room_id)
+
+        # set handle_queue_loop off in the background
+        run_as_background_process("persist_events", handle_queue_loop)
+
+    def _get_drainining_queue(self, room_id):
+        queue = self._event_persist_queues.setdefault(room_id, deque())
+
+        try:
+            while True:
+                yield queue.popleft()
+        except IndexError:
+            # Queue has been drained.
+            pass
+
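The draining generator above is a small but useful pattern: the consumer keeps pulling while producers append, and the loop only ends once the deque is empty. A standalone sketch:

```python
from collections import deque

def drain(queue):
    try:
        while True:
            yield queue.popleft()
    except IndexError:
        pass  # queue fully drained

q = deque([1, 2])
consumed = []
for item in drain(q):
    consumed.append(item)
    if item == 1:
        q.append(3)   # items queued mid-drain are still picked up
print(consumed)       # [1, 2, 3]
```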
+
+class EventsPersistenceStorage(object):
+    """High level interface for handling persisting newly received events.
+
+    Takes care of batching up events by room, and calculating the necessary
+    current state and forward extremity changes.
+    """
+
+    def __init__(self, hs, stores: DataStores):
+        # We ultimately want to split out the state store from the main store,
+        # so we use separate variables here even though they point to the same
+        # store for now.
+        self.main_store = stores.main
+        self.state_store = stores.state
+        self.persist_events_store = stores.persist_events
+
+        self._clock = hs.get_clock()
+        self.is_mine_id = hs.is_mine_id
+        self._event_persist_queue = _EventPeristenceQueue()
+        self._state_resolution_handler = hs.get_state_resolution_handler()
+
+    @defer.inlineCallbacks
+    def persist_events(
+        self,
+        events_and_contexts: List[Tuple[FrozenEvent, EventContext]],
+        backfilled: bool = False,
+    ):
+        """
+        Write events to the database
+        Args:
+            events_and_contexts: list of tuples of (event, context)
+            backfilled: Whether the results are retrieved from federation
+                via backfill or not. Used to determine if they're "new" events
+                which might update the current state etc.
+
+        Returns:
+            Deferred[int]: the stream ordering of the latest persisted event
+        """
+        partitioned = {}
+        for event, ctx in events_and_contexts:
+            partitioned.setdefault(event.room_id, []).append((event, ctx))
+
+        deferreds = []
+        for room_id, evs_ctxs in iteritems(partitioned):
+            d = self._event_persist_queue.add_to_queue(
+                room_id, evs_ctxs, backfilled=backfilled
+            )
+            deferreds.append(d)
+
+        for room_id in partitioned:
+            self._maybe_start_persisting(room_id)
+
+        yield make_deferred_yieldable(
+            defer.gatherResults(deferreds, consumeErrors=True)
+        )
+
+        max_persisted_id = yield self.main_store.get_current_events_token()
+
+        return max_persisted_id
+
+    @defer.inlineCallbacks
+    def persist_event(
+        self, event: FrozenEvent, context: EventContext, backfilled: bool = False
+    ):
+        """
+        Returns:
+            Deferred[Tuple[int, int]]: the stream ordering of ``event``,
+            and the stream ordering of the latest persisted event
+        """
+        deferred = self._event_persist_queue.add_to_queue(
+            event.room_id, [(event, context)], backfilled=backfilled
+        )
+
+        self._maybe_start_persisting(event.room_id)
+
+        yield make_deferred_yieldable(deferred)
+
+        max_persisted_id = yield self.main_store.get_current_events_token()
+        return (event.internal_metadata.stream_ordering, max_persisted_id)
+
+    def _maybe_start_persisting(self, room_id: str):
+        async def persisting_queue(item):
+            with Measure(self._clock, "persist_events"):
+                await self._persist_events(
+                    item.events_and_contexts, backfilled=item.backfilled
+                )
+
+        self._event_persist_queue.handle_queue(room_id, persisting_queue)
+
+    async def _persist_events(
+        self,
+        events_and_contexts: List[Tuple[FrozenEvent, EventContext]],
+        backfilled: bool = False,
+    ):
+        """Calculates the change to current state and forward extremities, and
+        persists the given events along with those updates.
+        """
+        if not events_and_contexts:
+            return
+
+        chunks = [
+            events_and_contexts[x : x + 100]
+            for x in range(0, len(events_and_contexts), 100)
+        ]
+
+        for chunk in chunks:
+            # We can't easily parallelize these since different chunks
+            # might contain the same event. :(
+
+            # NB: Assumes that we are only persisting events for one room
+            # at a time.
+
+            # map room_id->list[event_ids] giving the new forward
+            # extremities in each room
+            new_forward_extremeties = {}
+
+            # map room_id->(type,state_key)->event_id tracking the full
+            # state in each room after adding these events.
+            # This is simply used to prefill the get_current_state_ids
+            # cache
+            current_state_for_room = {}
+
+            # map room_id->(to_delete, to_insert) where to_delete is a list
+            # of type/state keys to remove from current state, and to_insert
+            # is a map (type,key)->event_id giving the state delta in each
+            # room
+            state_delta_for_room = {}
+
+            # Set of remote users which were in rooms the server has left. We
+            # should check if we still share any rooms and if not we mark their
+            # device lists as stale.
+            potentially_left_users = set()  # type: Set[str]
+
+            if not backfilled:
+                with Measure(self._clock, "_calculate_state_and_extrem"):
+                    # Work out the new "current state" for each room.
+                    # We do this by working out what the new extremities are and then
+                    # calculating the state from that.
+                    events_by_room = {}
+                    for event, context in chunk:
+                        events_by_room.setdefault(event.room_id, []).append(
+                            (event, context)
+                        )
+
+                    for room_id, ev_ctx_rm in iteritems(events_by_room):
+                        latest_event_ids = await self.main_store.get_latest_event_ids_in_room(
+                            room_id
+                        )
+                        new_latest_event_ids = await self._calculate_new_extremities(
+                            room_id, ev_ctx_rm, latest_event_ids
+                        )
+
+                        latest_event_ids = set(latest_event_ids)
+                        if new_latest_event_ids == latest_event_ids:
+                            # No change in extremities, so no change in state
+                            continue
+
+                        # there should always be at least one forward extremity.
+                        # (except during the initial persistence of the send_join
+                        # results, in which case there will be no existing
+                        # extremities, so we'll `continue` above and skip this bit.)
+                        assert new_latest_event_ids, "No forward extremities left!"
+
+                        new_forward_extremeties[room_id] = new_latest_event_ids
+
+                        len_1 = (
+                            len(latest_event_ids) == 1
+                            and len(new_latest_event_ids) == 1
+                        )
+                        if len_1:
+                            all_single_prev_not_state = all(
+                                len(event.prev_event_ids()) == 1
+                                and not event.is_state()
+                                for event, ctx in ev_ctx_rm
+                            )
+                            # Don't bother calculating state if they're just
+                            # a long chain of single ancestor non-state events.
+                            if all_single_prev_not_state:
+                                continue
+
+                        state_delta_counter.inc()
+                        if len(new_latest_event_ids) == 1:
+                            state_delta_single_event_counter.inc()
+
+                            # This is a fairly handwavey check to see if we could
+                            # have guessed what the delta would have been when
+                            # processing one of these events.
+                            # What we're interested in is if the latest extremities
+                            # were the same when we created the event as they are
+                            # now. When this server creates a new event (as opposed
+                            # to receiving it over federation) it will use the
+                            # forward extremities as the prev_events, so we can
+                            # guess this by looking at the prev_events and checking
+                            # if they match the current forward extremities.
+                            for ev, _ in ev_ctx_rm:
+                                prev_event_ids = set(ev.prev_event_ids())
+                                if latest_event_ids == prev_event_ids:
+                                    state_delta_reuse_delta_counter.inc()
+                                    break
+
+                        logger.debug("Calculating state delta for room %s", room_id)
+                        with Measure(
+                            self._clock, "persist_events.get_new_state_after_events"
+                        ):
+                            res = await self._get_new_state_after_events(
+                                room_id,
+                                ev_ctx_rm,
+                                latest_event_ids,
+                                new_latest_event_ids,
+                            )
+                            current_state, delta_ids = res
+
+                        # If either are not None then there has been a change,
+                        # and we need to work out the delta (or use that
+                        # given)
+                        delta = None
+                        if delta_ids is not None:
+                            # If there is a delta we know that we've
+                            # only added or replaced state, never
+                            # removed keys entirely.
+                            delta = DeltaState([], delta_ids)
+                        elif current_state is not None:
+                            with Measure(
+                                self._clock, "persist_events.calculate_state_delta"
+                            ):
+                                delta = await self._calculate_state_delta(
+                                    room_id, current_state
+                                )
+
+                        if delta:
+                            # If we have a change of state then lets check
+                            # whether we're actually still a member of the room,
+                            # or if our last user left. If we're no longer in
+                            # the room then we delete the current state and
+                            # extremities.
+                            is_still_joined = await self._is_server_still_joined(
+                                room_id,
+                                ev_ctx_rm,
+                                delta,
+                                current_state,
+                                potentially_left_users,
+                            )
+                            if not is_still_joined:
+                                logger.info("Server no longer in room %s", room_id)
+                                latest_event_ids = []
+                                current_state = {}
+                                delta.no_longer_in_room = True
+
+                            state_delta_for_room[room_id] = delta
+
+                        # If we have the current_state then lets prefill
+                        # the cache with it.
+                        if current_state is not None:
+                            current_state_for_room[room_id] = current_state
+
+            await self.persist_events_store._persist_events_and_state_updates(
+                chunk,
+                current_state_for_room=current_state_for_room,
+                state_delta_for_room=state_delta_for_room,
+                new_forward_extremeties=new_forward_extremeties,
+                backfilled=backfilled,
+            )
+
+            await self._handle_potentially_left_users(potentially_left_users)
+
+    async def _calculate_new_extremities(
+        self,
+        room_id: str,
+        event_contexts: List[Tuple[FrozenEvent, EventContext]],
+        latest_event_ids: List[str],
+    ):
+        """Calculates the new forward extremities for a room given events to
+        persist.
+
+        Assumes that we are only persisting events for one room at a time.
+        """
+
+        # we're only interested in new events which aren't outliers and which aren't
+        # being rejected.
+        new_events = [
+            event
+            for event, ctx in event_contexts
+            if not event.internal_metadata.is_outlier()
+            and not ctx.rejected
+            and not event.internal_metadata.is_soft_failed()
+        ]
+
+        latest_event_ids = set(latest_event_ids)
+
+        # start with the existing forward extremities
+        result = set(latest_event_ids)
+
+        # add all the new events to the list
+        result.update(event.event_id for event in new_events)
+
+        # Now remove all events which are prev_events of any of the new events
+        result.difference_update(
+            e_id for event in new_events for e_id in event.prev_event_ids()
+        )
+
+        # Remove any events which are prev_events of any existing events.
+        existing_prevs = await self.persist_events_store._get_events_which_are_prevs(
+            result
+        )
+        result.difference_update(existing_prevs)
+
+        # Finally handle the case where the new events have soft-failed prev
+        # events. If they do we need to remove them and their prev events,
+        # otherwise we end up with dangling extremities.
+        existing_prevs = await self.persist_events_store._get_prevs_before_rejected(
+            e_id for event in new_events for e_id in event.prev_event_ids()
+        )
+        result.difference_update(existing_prevs)
+
+        # We only update metrics for events that change forward extremities
+        # (e.g. we ignore backfill/outliers/etc)
+        if result != latest_event_ids:
+            forward_extremities_counter.observe(len(result))
+            stale = latest_event_ids & result
+            stale_forward_extremities_counter.observe(len(stale))
+
+        return result
+
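Ignoring the soft-failure handling at the end, the calculation above is plain set arithmetic. A sketch on made-up event ids (not the real API):

```python
def new_extremities(latest_event_ids, new_event_ids, prev_events_of, existing_prevs):
    """prev_events_of: event_id -> prev_event_ids for the events being persisted.
    existing_prevs: ids already referenced as prev_events by events we have."""
    result = set(latest_event_ids) | set(new_event_ids)
    result.difference_update(
        prev for ev in new_event_ids for prev in prev_events_of[ev]
    )
    result.difference_update(existing_prevs)
    return result

print(new_extremities(
    latest_event_ids={"$old"},
    new_event_ids=["$new1", "$new2"],
    prev_events_of={"$new1": ["$old"], "$new2": ["$new1"]},
    existing_prevs=set(),
))  # {'$new2'}: the only event nothing else points at
```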
+    async def _get_new_state_after_events(
+        self,
+        room_id: str,
+        events_context: List[Tuple[FrozenEvent, EventContext]],
+        old_latest_event_ids: Iterable[str],
+        new_latest_event_ids: Iterable[str],
+    ) -> Tuple[Optional[StateMap[str]], Optional[StateMap[str]]]:
+        """Calculate the current state dict after adding some new events to
+        a room
+
+        Args:
+            room_id (str):
+                room to which the events are being added. Used for logging etc
+
+            events_context (list[(EventBase, EventContext)]):
+                events and contexts which are being added to the room
+
+            old_latest_event_ids (iterable[str]):
+                the old forward extremities for the room.
+
+            new_latest_event_ids (iterable[str]):
+                the new forward extremities for the room.
+
+        Returns:
+            Returns a tuple of two state maps, the first being the full new current
+            state and the second being the delta to the existing current state.
+            If both are None then there has been no change.
+
+            If there has been a change then we only return the delta if it has
+            already been calculated. Conversely, if we do know the delta, then
+            the new current state is only returned if we've already calculated
+            it.
+        """
+        # map from state_group to ((type, key) -> event_id) state map
+        state_groups_map = {}
+
+        # Map from (prev state group, new state group) -> delta state dict
+        state_group_deltas = {}
+
+        for ev, ctx in events_context:
+            if ctx.state_group is None:
+                # This should only happen for outlier events.
+                if not ev.internal_metadata.is_outlier():
+                    raise Exception(
+                        "Context for new event %s has no state "
+                        "group" % (ev.event_id,)
+                    )
+                continue
+
+            if ctx.state_group in state_groups_map:
+                continue
+
+            # We're only interested in pulling out state that has already
+            # been cached in the context. We'll pull stuff out of the DB later
+            # if necessary.
+            current_state_ids = ctx.get_cached_current_state_ids()
+            if current_state_ids is not None:
+                state_groups_map[ctx.state_group] = current_state_ids
+
+            if ctx.prev_group:
+                state_group_deltas[(ctx.prev_group, ctx.state_group)] = ctx.delta_ids
+
+        # We need to map the event_ids to their state groups. First, let's
+        # check if the event is one we're persisting, in which case we can
+        # pull the state group from its context.
+        # Otherwise we need to pull the state group from the database.
+
+        # Set of events we need to fetch groups for. (We know none of the old
+        # extremities are going to be in events_context).
+        missing_event_ids = set(old_latest_event_ids)
+
+        event_id_to_state_group = {}
+        for event_id in new_latest_event_ids:
+            # First search in the list of new events we're adding.
+            for ev, ctx in events_context:
+                if event_id == ev.event_id and ctx.state_group is not None:
+                    event_id_to_state_group[event_id] = ctx.state_group
+                    break
+            else:
+                # If we couldn't find it, then we'll need to pull
+                # the state from the database
+                missing_event_ids.add(event_id)
+
+        if missing_event_ids:
+            # Now pull out the state groups for any missing events from DB
+            event_to_groups = await self.main_store._get_state_group_for_events(
+                missing_event_ids
+            )
+            event_id_to_state_group.update(event_to_groups)
+
+        # State groups of old_latest_event_ids
+        old_state_groups = {
+            event_id_to_state_group[evid] for evid in old_latest_event_ids
+        }
+
+        # State groups of new_latest_event_ids
+        new_state_groups = {
+            event_id_to_state_group[evid] for evid in new_latest_event_ids
+        }
+
+        # If the old and new groups are the same then we don't need to do
+        # anything.
+        if old_state_groups == new_state_groups:
+            return None, None
+
+        if len(new_state_groups) == 1 and len(old_state_groups) == 1:
+            # If we're going from one state group to another, lets check if
+            # we have a delta for that transition. If we do then we can just
+            # return that.
+
+            new_state_group = next(iter(new_state_groups))
+            old_state_group = next(iter(old_state_groups))
+
+            delta_ids = state_group_deltas.get((old_state_group, new_state_group), None)
+            if delta_ids is not None:
+                # We have a delta from the existing to new current state,
+                # so let's just return that. If we happen to already have
+                # the current state in memory then let's also return that,
+                # but it doesn't matter if we don't.
+                new_state = state_groups_map.get(new_state_group)
+                return new_state, delta_ids
+
+        # Now that we have calculated new_state_groups we need to get
+        # their state IDs so we can resolve to a single state set.
+        missing_state = new_state_groups - set(state_groups_map)
+        if missing_state:
+            group_to_state = await self.state_store._get_state_for_groups(missing_state)
+            state_groups_map.update(group_to_state)
+
+        if len(new_state_groups) == 1:
+            # If there is only one state group, then we know what the current
+            # state is.
+            return state_groups_map[new_state_groups.pop()], None
+
+        # Ok, we need to defer to the state handler to resolve our state sets.
+
+        state_groups = {sg: state_groups_map[sg] for sg in new_state_groups}
+
+        events_map = {ev.event_id: ev for ev, _ in events_context}
+
+        # We need to get the room version, which is in the create event.
+        # Normally that'd be in the database, but it's also possible that we're
+        # currently trying to persist it.
+        room_version = None
+        for ev, _ in events_context:
+            if ev.type == EventTypes.Create and ev.state_key == "":
+                room_version = ev.content.get("room_version", "1")
+                break
+
+        if not room_version:
+            room_version = await self.main_store.get_room_version_id(room_id)
+
+        logger.debug("calling resolve_state_groups from preserve_events")
+        res = await self._state_resolution_handler.resolve_state_groups(
+            room_id,
+            room_version,
+            state_groups,
+            events_map,
+            state_res_store=StateResolutionStore(self.main_store),
+        )
+
+        return res.state, None
+
+    async def _calculate_state_delta(
+        self, room_id: str, current_state: StateMap[str]
+    ) -> DeltaState:
+        """Calculate the new state deltas for a room.
+
+        Assumes that we are only persisting events for one room at a time.
+        """
+        existing_state = await self.main_store.get_current_state_ids(room_id)
+
+        to_delete = [key for key in existing_state if key not in current_state]
+
+        to_insert = {
+            key: ev_id
+            for key, ev_id in iteritems(current_state)
+            if ev_id != existing_state.get(key)
+        }
+
+        return DeltaState(to_delete=to_delete, to_insert=to_insert)
+
+    async def _is_server_still_joined(
+        self,
+        room_id: str,
+        ev_ctx_rm: List[Tuple[FrozenEvent, EventContext]],
+        delta: DeltaState,
+        current_state: Optional[StateMap[str]],
+        potentially_left_users: Set[str],
+    ) -> bool:
+        """Check if the server will still be joined after the given events have
+        been persisted.
+
+        Args:
+            room_id
+            ev_ctx_rm
+            delta: The delta of current state between what is in the database
+                and what the new current state will be.
+            current_state: The new current state if it has already been calculated,
+                otherwise None.
+            potentially_left_users: If the server has left the room, then joined
+                remote users will be added to this set to indicate that the
+                server may no longer be sharing a room with them.
+        """
+
+        if not any(
+            self.is_mine_id(state_key)
+            for typ, state_key in itertools.chain(delta.to_delete, delta.to_insert)
+            if typ == EventTypes.Member
+        ):
+            # There have been no changes to membership of our users, so nothing
+            # has changed and we assume we're still in the room.
+            return True
+
+        # Check if any of the given events is a local join that appears in the
+        # current state.
+        events_to_check = []  # Event IDs that aren't an event we're persisting
+        for (typ, state_key), event_id in delta.to_insert.items():
+            if typ != EventTypes.Member or not self.is_mine_id(state_key):
+                continue
+
+            for event, _ in ev_ctx_rm:
+                if event_id == event.event_id:
+                    if event.membership == Membership.JOIN:
+                        return True
+
+            # The event is not in `ev_ctx_rm`, so we need to pull it out of
+            # the DB.
+            events_to_check.append(event_id)
+
+        # Check if any of the changes that we don't have events for are joins.
+        if events_to_check:
+            rows = await self.main_store.get_membership_from_event_ids(events_to_check)
+            is_still_joined = any(row["membership"] == Membership.JOIN for row in rows)
+            if is_still_joined:
+                return True
+
+        # None of the new state events are local joins, so we check the database
+        # to see if there are any other local users in the room. We ignore users
+        # whose state has changed as we've already checked their new state above.
+        users_to_ignore = [
+            state_key
+            for typ, state_key in itertools.chain(delta.to_insert, delta.to_delete)
+            if typ == EventTypes.Member and self.is_mine_id(state_key)
+        ]
+
+        if await self.main_store.is_local_host_in_room_ignoring_users(
+            room_id, users_to_ignore
+        ):
+            return True
+
+        # The server will leave the room, so we go and find out which remote
+        # users will still be joined when we leave.
+        if current_state is None:
+            current_state = await self.main_store.get_current_state_ids(room_id)
+            current_state = dict(current_state)
+            for key in delta.to_delete:
+                current_state.pop(key, None)
+
+            current_state.update(delta.to_insert)
+
+        remote_event_ids = [
+            event_id
+            for (typ, state_key,), event_id in current_state.items()
+            if typ == EventTypes.Member and not self.is_mine_id(state_key)
+        ]
+        rows = await self.main_store.get_membership_from_event_ids(remote_event_ids)
+        potentially_left_users.update(
+            row["user_id"] for row in rows if row["membership"] == Membership.JOIN
+        )
+
+        return False
+
+    async def _handle_potentially_left_users(self, user_ids: Set[str]):
+        """Given a set of remote users check if the server still shares a room with
+        them. If not then mark those users' device cache as stale.
+        """
+
+        if not user_ids:
+            return
+
+        joined_users = await self.main_store.get_users_server_still_shares_room_with(
+            user_ids
+        )
+        left_users = user_ids - joined_users
+
+        for user_id in left_users:
+            await self.main_store.mark_remote_user_device_list_as_unsubscribed(user_id)
+
+    async def locally_reject_invite(self, user_id: str, room_id: str) -> int:
+        """Mark the invite has having been rejected even though we failed to
+        create a leave event for it.
+        """
+        return await self.persist_events_store.locally_reject_invite(user_id, room_id)
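As a quick illustration of the new `_calculate_state_delta` above, here is a minimal standalone sketch of the delta computation using plain dicts keyed by (event_type, state_key); the event IDs and state below are hypothetical:

# Minimal sketch of the to_delete / to_insert computation performed by
# _calculate_state_delta, on hypothetical data.
existing_state = {
    ("m.room.member", "@alice:example.com"): "$old_join",
    ("m.room.topic", ""): "$old_topic",
}
current_state = {
    ("m.room.member", "@alice:example.com"): "$new_leave",  # changed event id
    ("m.room.name", ""): "$name",                           # newly added key
    # ("m.room.topic", "") is absent, so it should be deleted
}

# Keys present before but missing from the new current state get deleted.
to_delete = [key for key in existing_state if key not in current_state]

# Keys whose event id changed (or that are new) get (re)inserted.
to_insert = {
    key: ev_id
    for key, ev_id in current_state.items()
    if ev_id != existing_state.get(key)
}

print(to_delete)        # [('m.room.topic', '')]
print(sorted(to_insert))  # the changed membership and the new room name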
diff --git a/synapse/storage/prepare_database.py b/synapse/storage/prepare_database.py
index e96eed8a6d..9cc3b51fe6 100644
--- a/synapse/storage/prepare_database.py
+++ b/synapse/storage/prepare_database.py
@@ -14,20 +14,27 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
-import fnmatch
 import imp
 import logging
 import os
 import re
+from collections import Counter
+from typing import TextIO
+
+import attr
 
 from synapse.storage.engines.postgres import PostgresEngine
+from synapse.storage.types import Cursor
 
 logger = logging.getLogger(__name__)
 
 
 # Remember to update this number every time a change is made to database
 # schema files, so the users will be informed on server restarts.
-SCHEMA_VERSION = 56
+# XXX: If you're about to bump this to 59 (or higher) please create an update
+# that drops the unused `cache_invalidation_stream` table, as per #7436!
+# XXX: Also add an update to drop `account_data_max_stream_id` as per #7656!
+SCHEMA_VERSION = 58
 
 dir_path = os.path.abspath(os.path.dirname(__file__))
 
@@ -40,7 +47,7 @@ class UpgradeDatabaseException(PrepareDatabaseException):
     pass
 
 
-def prepare_database(db_conn, database_engine, config):
+def prepare_database(db_conn, database_engine, config, data_stores=["main", "state"]):
     """Prepares a database for usage. Will either create all necessary tables
     or upgrade from an older schema version.
 
@@ -53,7 +60,10 @@ def prepare_database(db_conn, database_engine, config):
         config (synapse.config.homeserver.HomeServerConfig|None):
             application config, or None if we are connecting to an existing
             database which we expect to be configured already
+        data_stores (list[str]): The names of the data stores that will be used
+            with this database. Defaults to all data stores.
     """
+
     try:
         cur = db_conn.cursor()
         version_info = _get_or_create_schema_state(cur, database_engine)
@@ -65,13 +75,22 @@ def prepare_database(db_conn, database_engine, config):
                 if user_version != SCHEMA_VERSION:
                     # If we don't pass in a config file then we are expecting to
                     # have already upgraded the DB.
-                    raise UpgradeDatabaseException("Database needs to be upgraded")
+                    raise UpgradeDatabaseException(
+                        "Expected database schema version %i but got %i"
+                        % (SCHEMA_VERSION, user_version)
+                    )
             else:
                 _upgrade_existing_database(
-                    cur, user_version, delta_files, upgraded, database_engine, config
+                    cur,
+                    user_version,
+                    delta_files,
+                    upgraded,
+                    database_engine,
+                    config,
+                    data_stores=data_stores,
                 )
         else:
-            _setup_new_database(cur, database_engine)
+            _setup_new_database(cur, database_engine, data_stores=data_stores)
 
         # check if any of our configured dynamic modules want a database
         if config is not None:
@@ -84,9 +103,10 @@ def prepare_database(db_conn, database_engine, config):
         raise
 
 
-def _setup_new_database(cur, database_engine):
+def _setup_new_database(cur, database_engine, data_stores):
     """Sets up the database by finding a base set of "full schemas" and then
-    applying any necessary deltas.
+    applying any necessary deltas, including schemas from the given data
+    stores.
 
     The "full_schemas" directory has subdirectories named after versions. This
     function searches for the highest version less than or equal to
@@ -111,52 +131,83 @@ def _setup_new_database(cur, database_engine):
 
     In the example foo.sql and bar.sql would be run, and then any delta files
     for versions strictly greater than 11.
+
+    Note: we apply the full schemas and deltas from the top-level `schema/`
+    folder as well as those in the data stores specified.
+
+    Args:
+        cur (Cursor): a database cursor
+        database_engine (DatabaseEngine)
+        data_stores (list[str]): The names of the data stores to instantiate
+            on the given database.
     """
-    current_dir = os.path.join(dir_path, "schema", "full_schemas")
-    directory_entries = os.listdir(current_dir)
 
-    valid_dirs = []
-    pattern = re.compile(r"^\d+(\.sql)?$")
+    # We're about to set up a brand new database, so we check that it's
+    # configured to our liking.
+    database_engine.check_new_database(cur)
 
-    if isinstance(database_engine, PostgresEngine):
-        specific = "postgres"
-    else:
-        specific = "sqlite"
+    current_dir = os.path.join(dir_path, "schema", "full_schemas")
+    directory_entries = os.listdir(current_dir)
 
-    specific_pattern = re.compile(r"^\d+(\.sql." + specific + r")?$")
+    # First we find the highest full schema version we have
+    valid_versions = []
 
     for filename in directory_entries:
-        match = pattern.match(filename) or specific_pattern.match(filename)
-        abs_path = os.path.join(current_dir, filename)
-        if match and os.path.isdir(abs_path):
-            ver = int(match.group(0))
-            if ver <= SCHEMA_VERSION:
-                valid_dirs.append((ver, abs_path))
-        else:
-            logger.debug("Ignoring entry '%s' in 'full_schemas'", filename)
+        try:
+            ver = int(filename)
+        except ValueError:
+            continue
+
+        if ver <= SCHEMA_VERSION:
+            valid_versions.append(ver)
 
-    if not valid_dirs:
+    if not valid_versions:
         raise PrepareDatabaseException(
             "Could not find a suitable base set of full schemas"
         )
 
-    max_current_ver, sql_dir = max(valid_dirs, key=lambda x: x[0])
+    max_current_ver = max(valid_versions)
 
     logger.debug("Initialising schema v%d", max_current_ver)
 
-    directory_entries = os.listdir(sql_dir)
+    # Now let's find all the full schema files, both in the global schema and
+    # in data store schemas.
+    directories = [os.path.join(current_dir, str(max_current_ver))]
+    directories.extend(
+        os.path.join(
+            dir_path,
+            "data_stores",
+            data_store,
+            "schema",
+            "full_schemas",
+            str(max_current_ver),
+        )
+        for data_store in data_stores
+    )
+
+    directory_entries = []
+    for directory in directories:
+        directory_entries.extend(
+            _DirectoryListing(file_name, os.path.join(directory, file_name))
+            for file_name in os.listdir(directory)
+        )
+
+    if isinstance(database_engine, PostgresEngine):
+        specific = "postgres"
+    else:
+        specific = "sqlite"
 
-    for filename in sorted(
-        fnmatch.filter(directory_entries, "*.sql")
-        + fnmatch.filter(directory_entries, "*.sql." + specific)
-    ):
-        sql_loc = os.path.join(sql_dir, filename)
-        logger.debug("Applying schema %s", sql_loc)
-        executescript(cur, sql_loc)
+    directory_entries.sort()
+    for entry in directory_entries:
+        if entry.file_name.endswith(".sql") or entry.file_name.endswith(
+            ".sql." + specific
+        ):
+            logger.debug("Applying schema %s", entry.absolute_path)
+            executescript(cur, entry.absolute_path)
 
     cur.execute(
         database_engine.convert_param_style(
-            "INSERT INTO schema_version (version, upgraded)" " VALUES (?,?)"
+            "INSERT INTO schema_version (version, upgraded) VALUES (?,?)"
         ),
         (max_current_ver, False),
     )
@@ -168,6 +219,7 @@ def _setup_new_database(cur, database_engine):
         upgraded=False,
         database_engine=database_engine,
         config=None,
+        data_stores=data_stores,
         is_empty=True,
     )
 
@@ -179,6 +231,7 @@ def _upgrade_existing_database(
     upgraded,
     database_engine,
     config,
+    data_stores,
     is_empty=False,
 ):
     """Upgrades an existing database.
@@ -215,6 +268,10 @@ def _upgrade_existing_database(
     only if `upgraded` is True. Then `foo.sql` and `bar.py` would be run in
     some arbitrary order.
 
+    Note: we apply the delta files from the specified data stores as well as
+    those in the top-level schema. We apply all delta files across data stores
+    for a version before applying those in the next version.
+
     Args:
         cur (Cursor)
         current_version (int): The current version of the schema.
@@ -224,7 +281,19 @@ def _upgrade_existing_database(
             applied deltas or from full schema file. If `True` the function
             will never apply delta files for the given `current_version`, since
             the current_version wasn't generated by applying those delta files.
+        database_engine (DatabaseEngine)
+        config (synapse.config.homeserver.HomeServerConfig|None):
+            None if we are initialising a blank database, otherwise the application
+            config
+        data_stores (list[str]): The names of the data stores to instantiate
+            on the given database.
+        is_empty (bool): Is this a blank database? I.e. do we need to run the
+            upgrade portions of the delta scripts.
     """
+    if is_empty:
+        assert not applied_delta_files
+    else:
+        assert config
 
     if current_version > SCHEMA_VERSION:
         raise ValueError(
@@ -232,6 +301,13 @@ def _upgrade_existing_database(
             + "new for the server to understand"
         )
 
+    # some of the deltas assume that config.server_name is set correctly, so now
+    # is a good time to run the sanity check.
+    if not is_empty and "main" in data_stores:
+        from synapse.storage.data_stores.main import check_database_before_upgrade
+
+        check_database_before_upgrade(cur, database_engine, config)
+
     start_ver = current_version
     if not upgraded:
         start_ver += 1
@@ -248,24 +324,65 @@ def _upgrade_existing_database(
     for v in range(start_ver, SCHEMA_VERSION + 1):
         logger.info("Upgrading schema to v%d", v)
 
+        # We need to search both the global and per data store schema
+        # directories for schema updates.
+
+        # First we find the directories to search in
         delta_dir = os.path.join(dir_path, "schema", "delta", str(v))
+        directories = [delta_dir]
+        for data_store in data_stores:
+            directories.append(
+                os.path.join(
+                    dir_path, "data_stores", data_store, "schema", "delta", str(v)
+                )
+            )
 
-        try:
-            directory_entries = os.listdir(delta_dir)
-        except OSError:
-            logger.exception("Could not open delta dir for version %d", v)
-            raise UpgradeDatabaseException(
-                "Could not open delta dir for version %d" % (v,)
+        # Used to check if we have any duplicate file names
+        file_name_counter = Counter()
+
+        # Now find which directories have anything of interest.
+        directory_entries = []
+        for directory in directories:
+            logger.debug("Looking for schema deltas in %s", directory)
+            try:
+                file_names = os.listdir(directory)
+                directory_entries.extend(
+                    _DirectoryListing(file_name, os.path.join(directory, file_name))
+                    for file_name in file_names
+                )
+
+                for file_name in file_names:
+                    file_name_counter[file_name] += 1
+            except FileNotFoundError:
+                # Data stores can have empty entries for a given version delta.
+                pass
+            except OSError:
+                raise UpgradeDatabaseException(
+                    "Could not open delta dir for version %d: %s" % (v, directory)
+                )
+
+        duplicates = {
+            file_name for file_name, count in file_name_counter.items() if count > 1
+        }
+        if duplicates:
+            # We don't support using the same file name in the same delta version.
+            raise PrepareDatabaseException(
+                "Found multiple delta files with the same name in v%d: %s"
+                % (v, duplicates,)
             )
 
+        # We sort to ensure that we apply the delta files in a consistent
+        # order (to avoid bugs caused by inconsistent directory listing order)
         directory_entries.sort()
-        for file_name in directory_entries:
+        for entry in directory_entries:
+            file_name = entry.file_name
             relative_path = os.path.join(str(v), file_name)
-            logger.debug("Found file: %s", relative_path)
+            absolute_path = entry.absolute_path
+
+            logger.debug("Found file: %s (%s)", relative_path, absolute_path)
             if relative_path in applied_delta_files:
                 continue
 
-            absolute_path = os.path.join(dir_path, "schema", "delta", relative_path)
             root_name, ext = os.path.splitext(file_name)
             if ext == ".py":
                 # This is a python upgrade module. We need to import into some
@@ -352,7 +469,7 @@ def _apply_module_schema_files(cur, database_engine, modname, names_and_streams)
         ),
         (modname,),
     )
-    applied_deltas = set(d for d, in cur)
+    applied_deltas = {d for d, in cur}
     for (name, stream) in names_and_streams:
         if name in applied_deltas:
             continue
@@ -364,13 +481,12 @@ def _apply_module_schema_files(cur, database_engine, modname, names_and_streams)
             )
 
         logger.info("applying schema %s for %s", name, modname)
-        for statement in get_statements(stream):
-            cur.execute(statement)
+        execute_statements_from_stream(cur, stream)
 
         # Mark as done.
         cur.execute(
             database_engine.convert_param_style(
-                "INSERT INTO applied_module_schemas (module_name, file)" " VALUES (?,?)"
+                "INSERT INTO applied_module_schemas (module_name, file) VALUES (?,?)"
             ),
             (modname, name),
         )
@@ -423,8 +539,12 @@ def get_statements(f):
 
 def executescript(txn, schema_path):
     with open(schema_path, "r") as f:
-        for statement in get_statements(f):
-            txn.execute(statement)
+        execute_statements_from_stream(txn, f)
+
+
+def execute_statements_from_stream(cur: Cursor, f: TextIO):
+    for statement in get_statements(f):
+        cur.execute(statement)
 
 
 def _get_or_create_schema_state(txn, database_engine):
@@ -448,3 +568,16 @@ def _get_or_create_schema_state(txn, database_engine):
         return current_version, applied_deltas, upgraded
 
     return None
+
+
+@attr.s()
+class _DirectoryListing(object):
+    """Helper class to store schema file name and the
+    absolute path to it.
+
+    These entries get sorted, so for consistency we want to ensure that
+    `file_name` attr is kept first.
+    """
+
+    file_name = attr.ib()
+    absolute_path = attr.ib()
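To make the new delta-collection behaviour concrete, here is a rough sketch of how files from the global schema directory and the per-data-store directories are merged, checked for duplicate names and applied in a stable order. It uses in-memory listings rather than the real filesystem, and the directory and file names are hypothetical:

from collections import Counter

import attr


@attr.s
class DirectoryListing(object):
    # Mirrors _DirectoryListing: attrs compares instances attribute by
    # attribute, so keeping file_name first gives a stable sort by file name.
    file_name = attr.ib()
    absolute_path = attr.ib()


# Hypothetical listings for one schema version.
listings = {
    "schema/delta/58": ["00background_update.sql"],
    "data_stores/main/schema/delta/58": ["02rooms.sql", "01devices.py"],
    "data_stores/state/schema/delta/58": ["03groups.sql"],
}

entries = []
name_counter = Counter()
for directory, file_names in listings.items():
    for file_name in file_names:
        entries.append(DirectoryListing(file_name, directory + "/" + file_name))
        name_counter[file_name] += 1

# The same file name may not appear twice within one schema version.
duplicates = {name for name, count in name_counter.items() if count > 1}
assert not duplicates, duplicates

# Sorting the entries gives a consistent application order regardless of the
# order in which the directories were listed.
for entry in sorted(entries):
    print("would apply", entry.absolute_path)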
diff --git a/synapse/storage/presence.py b/synapse/storage/presence.py
index 5db6f2d84a..18a462f0ee 100644
--- a/synapse/storage/presence.py
+++ b/synapse/storage/presence.py
@@ -15,13 +15,7 @@
 
 from collections import namedtuple
 
-from twisted.internet import defer
-
 from synapse.api.constants import PresenceState
-from synapse.util import batch_iter
-from synapse.util.caches.descriptors import cached, cachedList
-
-from ._base import SQLBaseStore
 
 
 class UserPresenceState(
@@ -73,133 +67,3 @@ class UserPresenceState(
             status_msg=None,
             currently_active=False,
         )
-
-
-class PresenceStore(SQLBaseStore):
-    @defer.inlineCallbacks
-    def update_presence(self, presence_states):
-        stream_ordering_manager = self._presence_id_gen.get_next_mult(
-            len(presence_states)
-        )
-
-        with stream_ordering_manager as stream_orderings:
-            yield self.runInteraction(
-                "update_presence",
-                self._update_presence_txn,
-                stream_orderings,
-                presence_states,
-            )
-
-        return stream_orderings[-1], self._presence_id_gen.get_current_token()
-
-    def _update_presence_txn(self, txn, stream_orderings, presence_states):
-        for stream_id, state in zip(stream_orderings, presence_states):
-            txn.call_after(
-                self.presence_stream_cache.entity_has_changed, state.user_id, stream_id
-            )
-            txn.call_after(self._get_presence_for_user.invalidate, (state.user_id,))
-
-        # Actually insert new rows
-        self._simple_insert_many_txn(
-            txn,
-            table="presence_stream",
-            values=[
-                {
-                    "stream_id": stream_id,
-                    "user_id": state.user_id,
-                    "state": state.state,
-                    "last_active_ts": state.last_active_ts,
-                    "last_federation_update_ts": state.last_federation_update_ts,
-                    "last_user_sync_ts": state.last_user_sync_ts,
-                    "status_msg": state.status_msg,
-                    "currently_active": state.currently_active,
-                }
-                for state in presence_states
-            ],
-        )
-
-        # Delete old rows to stop database from getting really big
-        sql = (
-            "DELETE FROM presence_stream WHERE" " stream_id < ?" " AND user_id IN (%s)"
-        )
-
-        for states in batch_iter(presence_states, 50):
-            args = [stream_id]
-            args.extend(s.user_id for s in states)
-            txn.execute(sql % (",".join("?" for _ in states),), args)
-
-    def get_all_presence_updates(self, last_id, current_id):
-        if last_id == current_id:
-            return defer.succeed([])
-
-        def get_all_presence_updates_txn(txn):
-            sql = (
-                "SELECT stream_id, user_id, state, last_active_ts,"
-                " last_federation_update_ts, last_user_sync_ts, status_msg,"
-                " currently_active"
-                " FROM presence_stream"
-                " WHERE ? < stream_id AND stream_id <= ?"
-            )
-            txn.execute(sql, (last_id, current_id))
-            return txn.fetchall()
-
-        return self.runInteraction(
-            "get_all_presence_updates", get_all_presence_updates_txn
-        )
-
-    @cached()
-    def _get_presence_for_user(self, user_id):
-        raise NotImplementedError()
-
-    @cachedList(
-        cached_method_name="_get_presence_for_user",
-        list_name="user_ids",
-        num_args=1,
-        inlineCallbacks=True,
-    )
-    def get_presence_for_users(self, user_ids):
-        rows = yield self._simple_select_many_batch(
-            table="presence_stream",
-            column="user_id",
-            iterable=user_ids,
-            keyvalues={},
-            retcols=(
-                "user_id",
-                "state",
-                "last_active_ts",
-                "last_federation_update_ts",
-                "last_user_sync_ts",
-                "status_msg",
-                "currently_active",
-            ),
-            desc="get_presence_for_users",
-        )
-
-        for row in rows:
-            row["currently_active"] = bool(row["currently_active"])
-
-        return {row["user_id"]: UserPresenceState(**row) for row in rows}
-
-    def get_current_presence_token(self):
-        return self._presence_id_gen.get_current_token()
-
-    def allow_presence_visible(self, observed_localpart, observer_userid):
-        return self._simple_insert(
-            table="presence_allow_inbound",
-            values={
-                "observed_user_id": observed_localpart,
-                "observer_user_id": observer_userid,
-            },
-            desc="allow_presence_visible",
-            or_ignore=True,
-        )
-
-    def disallow_presence_visible(self, observed_localpart, observer_userid):
-        return self._simple_delete_one(
-            table="presence_allow_inbound",
-            keyvalues={
-                "observed_user_id": observed_localpart,
-                "observer_user_id": observer_userid,
-            },
-            desc="disallow_presence_visible",
-        )
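The removed `_update_presence_txn` trims old presence_stream rows in batches so that the parameterised IN (...) clause stays bounded. Below is a minimal sketch of that batching pattern, using an in-memory SQLite table and a simplified stand-in for synapse.util.batch_iter; the table contents are hypothetical:

import sqlite3


def batch_iter(iterable, size):
    # Simplified stand-in for synapse.util.batch_iter: yield fixed-size chunks.
    items = list(iterable)
    for i in range(0, len(items), size):
        yield items[i : i + size]


conn = sqlite3.connect(":memory:")
conn.execute("CREATE TABLE presence_stream (stream_id INTEGER, user_id TEXT)")
conn.executemany(
    "INSERT INTO presence_stream VALUES (?, ?)",
    [(n, "@user%d:example.com" % n) for n in range(120)],
)

max_stream_id = 100
user_ids = ["@user%d:example.com" % n for n in range(120)]

for batch in batch_iter(user_ids, 50):
    placeholders = ",".join("?" for _ in batch)
    conn.execute(
        "DELETE FROM presence_stream WHERE stream_id < ? AND user_id IN (%s)"
        % (placeholders,),
        [max_stream_id] + list(batch),
    )

print(conn.execute("SELECT COUNT(*) FROM presence_stream").fetchone()[0])  # 20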
diff --git a/synapse/storage/purge_events.py b/synapse/storage/purge_events.py
new file mode 100644
index 0000000000..fdc0abf5cf
--- /dev/null
+++ b/synapse/storage/purge_events.py
@@ -0,0 +1,117 @@
+# -*- coding: utf-8 -*-
+# Copyright 2019 The Matrix.org Foundation C.I.C.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import itertools
+import logging
+
+from twisted.internet import defer
+
+logger = logging.getLogger(__name__)
+
+
+class PurgeEventsStorage(object):
+    """High level interface for purging rooms and event history.
+    """
+
+    def __init__(self, hs, stores):
+        self.stores = stores
+
+    @defer.inlineCallbacks
+    def purge_room(self, room_id: str):
+        """Deletes all record of a room
+        """
+
+        state_groups_to_delete = yield self.stores.main.purge_room(room_id)
+        yield self.stores.state.purge_room_state(room_id, state_groups_to_delete)
+
+    @defer.inlineCallbacks
+    def purge_history(self, room_id, token, delete_local_events):
+        """Deletes room history before a certain point
+
+        Args:
+            room_id (str):
+
+            token (str): A topological token to delete events before
+
+            delete_local_events (bool):
+                if True, we will delete local events as well as remote ones
+                (instead of just marking them as outliers and deleting their
+                state groups).
+        """
+        state_groups = yield self.stores.main.purge_history(
+            room_id, token, delete_local_events
+        )
+
+        logger.info("[purge] finding state groups that can be deleted")
+
+        sg_to_delete = yield self._find_unreferenced_groups(state_groups)
+
+        yield self.stores.state.purge_unreferenced_state_groups(room_id, sg_to_delete)
+
+    @defer.inlineCallbacks
+    def _find_unreferenced_groups(self, state_groups):
+        """Used when purging history to figure out which state groups can be
+        deleted.
+
+        Args:
+            state_groups (set[int]): Set of state groups referenced by events
+                that are going to be deleted.
+
+        Returns:
+            Deferred[set[int]] The set of state groups that can be deleted.
+        """
+        # Graph of state group -> previous group
+        graph = {}
+
+        # Set of state groups that we have found to be referenced by events
+        referenced_groups = set()
+
+        # Set of state groups we've already seen
+        state_groups_seen = set(state_groups)
+
+        # Set of state groups to handle next.
+        next_to_search = set(state_groups)
+        while next_to_search:
+            # We bound the size of the groups we're looking up at once, to stop
+            # the SQL query getting too big.
+            if len(next_to_search) < 100:
+                current_search = next_to_search
+                next_to_search = set()
+            else:
+                current_search = set(itertools.islice(next_to_search, 100))
+                next_to_search -= current_search
+
+            referenced = yield self.stores.main.get_referenced_state_groups(
+                current_search
+            )
+            referenced_groups |= referenced
+
+            # We don't continue iterating up the state group graphs for state
+            # groups that are referenced.
+            current_search -= referenced
+
+            edges = yield self.stores.state.get_previous_state_groups(current_search)
+
+            prevs = set(edges.values())
+            # We don't bother re-handling groups we've already seen
+            prevs -= state_groups_seen
+            next_to_search |= prevs
+            state_groups_seen |= prevs
+
+            graph.update(edges)
+
+        to_delete = state_groups_seen - referenced_groups
+
+        return to_delete
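To make the traversal in `_find_unreferenced_groups` concrete, here is a minimal sketch over a hypothetical state-group graph, using plain sets and dicts: it walks the previous-group edges upwards from the groups being purged and stops at any group that is still referenced by events.

# Hypothetical graph: state_group -> previous state_group.
prev_group = {4: 3, 3: 2, 2: 1}
# Groups still referenced by events that are not being deleted.
referenced_by_events = {2}

# Groups referenced by the events we are about to purge.
state_groups = {4}
seen = set(state_groups)
next_to_search = set(state_groups)
referenced = set()

while next_to_search:
    current = next_to_search
    next_to_search = set()

    # Stop walking past any group that live events still reference.
    hits = {sg for sg in current if sg in referenced_by_events}
    referenced |= hits
    current -= hits

    prevs = {prev_group[sg] for sg in current if sg in prev_group}
    prevs -= seen
    next_to_search |= prevs
    seen |= prevs

to_delete = seen - referenced
print(to_delete)  # {3, 4}: group 2 is still referenced, group 1 is never reached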
diff --git a/synapse/storage/push_rule.py b/synapse/storage/push_rule.py
index a6517c4cf3..f47cec0d86 100644
--- a/synapse/storage/push_rule.py
+++ b/synapse/storage/push_rule.py
@@ -14,708 +14,6 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
-import abc
-import logging
-
-from canonicaljson import json
-
-from twisted.internet import defer
-
-from synapse.push.baserules import list_with_base_rules
-from synapse.storage.appservice import ApplicationServiceWorkerStore
-from synapse.storage.pusher import PusherWorkerStore
-from synapse.storage.receipts import ReceiptsWorkerStore
-from synapse.storage.roommember import RoomMemberWorkerStore
-from synapse.util.caches.descriptors import cachedInlineCallbacks, cachedList
-from synapse.util.caches.stream_change_cache import StreamChangeCache
-
-from ._base import SQLBaseStore
-
-logger = logging.getLogger(__name__)
-
-
-def _load_rules(rawrules, enabled_map):
-    ruleslist = []
-    for rawrule in rawrules:
-        rule = dict(rawrule)
-        rule["conditions"] = json.loads(rawrule["conditions"])
-        rule["actions"] = json.loads(rawrule["actions"])
-        ruleslist.append(rule)
-
-    # We're going to be mutating this a lot, so do a deep copy
-    rules = list(list_with_base_rules(ruleslist))
-
-    for i, rule in enumerate(rules):
-        rule_id = rule["rule_id"]
-        if rule_id in enabled_map:
-            if rule.get("enabled", True) != bool(enabled_map[rule_id]):
-                # Rules are cached across users.
-                rule = dict(rule)
-                rule["enabled"] = bool(enabled_map[rule_id])
-                rules[i] = rule
-
-    return rules
-
-
-class PushRulesWorkerStore(
-    ApplicationServiceWorkerStore,
-    ReceiptsWorkerStore,
-    PusherWorkerStore,
-    RoomMemberWorkerStore,
-    SQLBaseStore,
-):
-    """This is an abstract base class where subclasses must implement
-    `get_max_push_rules_stream_id` which can be called in the initializer.
-    """
-
-    # This ABCMeta metaclass ensures that we cannot be instantiated without
-    # the abstract methods being implemented.
-    __metaclass__ = abc.ABCMeta
-
-    def __init__(self, db_conn, hs):
-        super(PushRulesWorkerStore, self).__init__(db_conn, hs)
-
-        push_rules_prefill, push_rules_id = self._get_cache_dict(
-            db_conn,
-            "push_rules_stream",
-            entity_column="user_id",
-            stream_column="stream_id",
-            max_value=self.get_max_push_rules_stream_id(),
-        )
-
-        self.push_rules_stream_cache = StreamChangeCache(
-            "PushRulesStreamChangeCache",
-            push_rules_id,
-            prefilled_cache=push_rules_prefill,
-        )
-
-    @abc.abstractmethod
-    def get_max_push_rules_stream_id(self):
-        """Get the position of the push rules stream.
-
-        Returns:
-            int
-        """
-        raise NotImplementedError()
-
-    @cachedInlineCallbacks(max_entries=5000)
-    def get_push_rules_for_user(self, user_id):
-        rows = yield self._simple_select_list(
-            table="push_rules",
-            keyvalues={"user_name": user_id},
-            retcols=(
-                "user_name",
-                "rule_id",
-                "priority_class",
-                "priority",
-                "conditions",
-                "actions",
-            ),
-            desc="get_push_rules_enabled_for_user",
-        )
-
-        rows.sort(key=lambda row: (-int(row["priority_class"]), -int(row["priority"])))
-
-        enabled_map = yield self.get_push_rules_enabled_for_user(user_id)
-
-        rules = _load_rules(rows, enabled_map)
-
-        return rules
-
-    @cachedInlineCallbacks(max_entries=5000)
-    def get_push_rules_enabled_for_user(self, user_id):
-        results = yield self._simple_select_list(
-            table="push_rules_enable",
-            keyvalues={"user_name": user_id},
-            retcols=("user_name", "rule_id", "enabled"),
-            desc="get_push_rules_enabled_for_user",
-        )
-        return {r["rule_id"]: False if r["enabled"] == 0 else True for r in results}
-
-    def have_push_rules_changed_for_user(self, user_id, last_id):
-        if not self.push_rules_stream_cache.has_entity_changed(user_id, last_id):
-            return defer.succeed(False)
-        else:
-
-            def have_push_rules_changed_txn(txn):
-                sql = (
-                    "SELECT COUNT(stream_id) FROM push_rules_stream"
-                    " WHERE user_id = ? AND ? < stream_id"
-                )
-                txn.execute(sql, (user_id, last_id))
-                count, = txn.fetchone()
-                return bool(count)
-
-            return self.runInteraction(
-                "have_push_rules_changed", have_push_rules_changed_txn
-            )
-
-    @cachedList(
-        cached_method_name="get_push_rules_for_user",
-        list_name="user_ids",
-        num_args=1,
-        inlineCallbacks=True,
-    )
-    def bulk_get_push_rules(self, user_ids):
-        if not user_ids:
-            return {}
-
-        results = {user_id: [] for user_id in user_ids}
-
-        rows = yield self._simple_select_many_batch(
-            table="push_rules",
-            column="user_name",
-            iterable=user_ids,
-            retcols=("*",),
-            desc="bulk_get_push_rules",
-        )
-
-        rows.sort(key=lambda row: (-int(row["priority_class"]), -int(row["priority"])))
-
-        for row in rows:
-            results.setdefault(row["user_name"], []).append(row)
-
-        enabled_map_by_user = yield self.bulk_get_push_rules_enabled(user_ids)
-
-        for user_id, rules in results.items():
-            results[user_id] = _load_rules(rules, enabled_map_by_user.get(user_id, {}))
-
-        return results
-
-    @defer.inlineCallbacks
-    def move_push_rule_from_room_to_room(self, new_room_id, user_id, rule):
-        """Move a single push rule from one room to another for a specific user.
-
-        Args:
-            new_room_id (str): ID of the new room.
-            user_id (str): ID of user the push rule belongs to.
-            rule (Dict): A push rule.
-        """
-        # Create new rule id
-        rule_id_scope = "/".join(rule["rule_id"].split("/")[:-1])
-        new_rule_id = rule_id_scope + "/" + new_room_id
-
-        # Change room id in each condition
-        for condition in rule.get("conditions", []):
-            if condition.get("key") == "room_id":
-                condition["pattern"] = new_room_id
-
-        # Add the rule for the new room
-        yield self.add_push_rule(
-            user_id=user_id,
-            rule_id=new_rule_id,
-            priority_class=rule["priority_class"],
-            conditions=rule["conditions"],
-            actions=rule["actions"],
-        )
-
-        # Delete push rule for the old room
-        yield self.delete_push_rule(user_id, rule["rule_id"])
-
-    @defer.inlineCallbacks
-    def move_push_rules_from_room_to_room_for_user(
-        self, old_room_id, new_room_id, user_id
-    ):
-        """Move all of the push rules from one room to another for a specific
-        user.
-
-        Args:
-            old_room_id (str): ID of the old room.
-            new_room_id (str): ID of the new room.
-            user_id (str): ID of user to copy push rules for.
-        """
-        # Retrieve push rules for this user
-        user_push_rules = yield self.get_push_rules_for_user(user_id)
-
-        # Get rules relating to the old room, move them to the new room, then
-        # delete them from the old room
-        for rule in user_push_rules:
-            conditions = rule.get("conditions", [])
-            if any(
-                (c.get("key") == "room_id" and c.get("pattern") == old_room_id)
-                for c in conditions
-            ):
-                self.move_push_rule_from_room_to_room(new_room_id, user_id, rule)
-
-    @defer.inlineCallbacks
-    def bulk_get_push_rules_for_room(self, event, context):
-        state_group = context.state_group
-        if not state_group:
-            # If state_group is None it means it has yet to be assigned a
-            # state group, i.e. we need to make sure that calls with a state_group
-            # of None don't hit previous cached calls with a None state_group.
-            # To do this we set the state_group to a new object as object() != object()
-            state_group = object()
-
-        current_state_ids = yield context.get_current_state_ids(self)
-        result = yield self._bulk_get_push_rules_for_room(
-            event.room_id, state_group, current_state_ids, event=event
-        )
-        return result
-
-    @cachedInlineCallbacks(num_args=2, cache_context=True)
-    def _bulk_get_push_rules_for_room(
-        self, room_id, state_group, current_state_ids, cache_context, event=None
-    ):
-        # We don't use `state_group`, its there so that we can cache based
-        # on it. However, its important that its never None, since two current_state's
-        # with a state_group of None are likely to be different.
-        # See bulk_get_push_rules_for_room for how we work around this.
-        assert state_group is not None
-
-        # We also will want to generate notifs for other people in the room so
-        # their unread countss are correct in the event stream, but to avoid
-        # generating them for bot / AS users etc, we only do so for people who've
-        # sent a read receipt into the room.
-
-        users_in_room = yield self._get_joined_users_from_context(
-            room_id,
-            state_group,
-            current_state_ids,
-            on_invalidate=cache_context.invalidate,
-            event=event,
-        )
-
-        # We ignore app service users for now. This is so that we don't fill
-        # up the `get_if_users_have_pushers` cache with AS entries that we
-        # know don't have pushers, nor even read receipts.
-        local_users_in_room = set(
-            u
-            for u in users_in_room
-            if self.hs.is_mine_id(u)
-            and not self.get_if_app_services_interested_in_user(u)
-        )
-
-        # users in the room who have pushers need to get push rules run because
-        # that's how their pushers work
-        if_users_with_pushers = yield self.get_if_users_have_pushers(
-            local_users_in_room, on_invalidate=cache_context.invalidate
-        )
-        user_ids = set(
-            uid for uid, have_pusher in if_users_with_pushers.items() if have_pusher
-        )
-
-        users_with_receipts = yield self.get_users_with_read_receipts_in_room(
-            room_id, on_invalidate=cache_context.invalidate
-        )
-
-        # any users with pushers must be ours: they have pushers
-        for uid in users_with_receipts:
-            if uid in local_users_in_room:
-                user_ids.add(uid)
-
-        rules_by_user = yield self.bulk_get_push_rules(
-            user_ids, on_invalidate=cache_context.invalidate
-        )
-
-        rules_by_user = {k: v for k, v in rules_by_user.items() if v is not None}
-
-        return rules_by_user
-
-    @cachedList(
-        cached_method_name="get_push_rules_enabled_for_user",
-        list_name="user_ids",
-        num_args=1,
-        inlineCallbacks=True,
-    )
-    def bulk_get_push_rules_enabled(self, user_ids):
-        if not user_ids:
-            return {}
-
-        results = {user_id: {} for user_id in user_ids}
-
-        rows = yield self._simple_select_many_batch(
-            table="push_rules_enable",
-            column="user_name",
-            iterable=user_ids,
-            retcols=("user_name", "rule_id", "enabled"),
-            desc="bulk_get_push_rules_enabled",
-        )
-        for row in rows:
-            enabled = bool(row["enabled"])
-            results.setdefault(row["user_name"], {})[row["rule_id"]] = enabled
-        return results
-
-
-class PushRuleStore(PushRulesWorkerStore):
-    @defer.inlineCallbacks
-    def add_push_rule(
-        self,
-        user_id,
-        rule_id,
-        priority_class,
-        conditions,
-        actions,
-        before=None,
-        after=None,
-    ):
-        conditions_json = json.dumps(conditions)
-        actions_json = json.dumps(actions)
-        with self._push_rules_stream_id_gen.get_next() as ids:
-            stream_id, event_stream_ordering = ids
-            if before or after:
-                yield self.runInteraction(
-                    "_add_push_rule_relative_txn",
-                    self._add_push_rule_relative_txn,
-                    stream_id,
-                    event_stream_ordering,
-                    user_id,
-                    rule_id,
-                    priority_class,
-                    conditions_json,
-                    actions_json,
-                    before,
-                    after,
-                )
-            else:
-                yield self.runInteraction(
-                    "_add_push_rule_highest_priority_txn",
-                    self._add_push_rule_highest_priority_txn,
-                    stream_id,
-                    event_stream_ordering,
-                    user_id,
-                    rule_id,
-                    priority_class,
-                    conditions_json,
-                    actions_json,
-                )
-
-    def _add_push_rule_relative_txn(
-        self,
-        txn,
-        stream_id,
-        event_stream_ordering,
-        user_id,
-        rule_id,
-        priority_class,
-        conditions_json,
-        actions_json,
-        before,
-        after,
-    ):
-        # Lock the table since otherwise we'll have annoying races between the
-        # SELECT here and the UPSERT below.
-        self.database_engine.lock_table(txn, "push_rules")
-
-        relative_to_rule = before or after
-
-        res = self._simple_select_one_txn(
-            txn,
-            table="push_rules",
-            keyvalues={"user_name": user_id, "rule_id": relative_to_rule},
-            retcols=["priority_class", "priority"],
-            allow_none=True,
-        )
-
-        if not res:
-            raise RuleNotFoundException(
-                "before/after rule not found: %s" % (relative_to_rule,)
-            )
-
-        base_priority_class = res["priority_class"]
-        base_rule_priority = res["priority"]
-
-        if base_priority_class != priority_class:
-            raise InconsistentRuleException(
-                "Given priority class does not match class of relative rule"
-            )
-
-        if before:
-            # Higher priority rules are executed first, So adding a rule before
-            # a rule means giving it a higher priority than that rule.
-            new_rule_priority = base_rule_priority + 1
-        else:
-            # We increment the priority of the existing rules to make space for
-            # the new rule. Therefore if we want this rule to appear after
-            # an existing rule we give it the priority of the existing rule,
-            # and then increment the priority of the existing rule.
-            new_rule_priority = base_rule_priority
-
-        sql = (
-            "UPDATE push_rules SET priority = priority + 1"
-            " WHERE user_name = ? AND priority_class = ? AND priority >= ?"
-        )
-
-        txn.execute(sql, (user_id, priority_class, new_rule_priority))
-
-        self._upsert_push_rule_txn(
-            txn,
-            stream_id,
-            event_stream_ordering,
-            user_id,
-            rule_id,
-            priority_class,
-            new_rule_priority,
-            conditions_json,
-            actions_json,
-        )
-
-    def _add_push_rule_highest_priority_txn(
-        self,
-        txn,
-        stream_id,
-        event_stream_ordering,
-        user_id,
-        rule_id,
-        priority_class,
-        conditions_json,
-        actions_json,
-    ):
-        # Lock the table since otherwise we'll have annoying races between the
-        # SELECT here and the UPSERT below.
-        self.database_engine.lock_table(txn, "push_rules")
-
-        # find the highest priority rule in that class
-        sql = (
-            "SELECT COUNT(*), MAX(priority) FROM push_rules"
-            " WHERE user_name = ? and priority_class = ?"
-        )
-        txn.execute(sql, (user_id, priority_class))
-        res = txn.fetchall()
-        (how_many, highest_prio) = res[0]
-
-        new_prio = 0
-        if how_many > 0:
-            new_prio = highest_prio + 1
-
-        self._upsert_push_rule_txn(
-            txn,
-            stream_id,
-            event_stream_ordering,
-            user_id,
-            rule_id,
-            priority_class,
-            new_prio,
-            conditions_json,
-            actions_json,
-        )
-
-    def _upsert_push_rule_txn(
-        self,
-        txn,
-        stream_id,
-        event_stream_ordering,
-        user_id,
-        rule_id,
-        priority_class,
-        priority,
-        conditions_json,
-        actions_json,
-        update_stream=True,
-    ):
-        """Specialised version of _simple_upsert_txn that picks a push_rule_id
-        using the _push_rule_id_gen if it needs to insert the rule. It assumes
-        that the "push_rules" table is locked"""
-
-        sql = (
-            "UPDATE push_rules"
-            " SET priority_class = ?, priority = ?, conditions = ?, actions = ?"
-            " WHERE user_name = ? AND rule_id = ?"
-        )
-
-        txn.execute(
-            sql,
-            (priority_class, priority, conditions_json, actions_json, user_id, rule_id),
-        )
-
-        if txn.rowcount == 0:
-            # We didn't update a row with the given rule_id so insert one
-            push_rule_id = self._push_rule_id_gen.get_next()
-
-            self._simple_insert_txn(
-                txn,
-                table="push_rules",
-                values={
-                    "id": push_rule_id,
-                    "user_name": user_id,
-                    "rule_id": rule_id,
-                    "priority_class": priority_class,
-                    "priority": priority,
-                    "conditions": conditions_json,
-                    "actions": actions_json,
-                },
-            )
-
-        if update_stream:
-            self._insert_push_rules_update_txn(
-                txn,
-                stream_id,
-                event_stream_ordering,
-                user_id,
-                rule_id,
-                op="ADD",
-                data={
-                    "priority_class": priority_class,
-                    "priority": priority,
-                    "conditions": conditions_json,
-                    "actions": actions_json,
-                },
-            )
-
-    @defer.inlineCallbacks
-    def delete_push_rule(self, user_id, rule_id):
-        """
-        Delete a push rule. Args specify the row to be deleted and can be
-        any of the columns in the push_rule table, but below are the
-        standard ones
-
-        Args:
-            user_id (str): The matrix ID of the push rule owner
-            rule_id (str): The rule_id of the rule to be deleted
-        """
-
-        def delete_push_rule_txn(txn, stream_id, event_stream_ordering):
-            self._simple_delete_one_txn(
-                txn, "push_rules", {"user_name": user_id, "rule_id": rule_id}
-            )
-
-            self._insert_push_rules_update_txn(
-                txn, stream_id, event_stream_ordering, user_id, rule_id, op="DELETE"
-            )
-
-        with self._push_rules_stream_id_gen.get_next() as ids:
-            stream_id, event_stream_ordering = ids
-            yield self.runInteraction(
-                "delete_push_rule",
-                delete_push_rule_txn,
-                stream_id,
-                event_stream_ordering,
-            )
-
-    @defer.inlineCallbacks
-    def set_push_rule_enabled(self, user_id, rule_id, enabled):
-        with self._push_rules_stream_id_gen.get_next() as ids:
-            stream_id, event_stream_ordering = ids
-            yield self.runInteraction(
-                "_set_push_rule_enabled_txn",
-                self._set_push_rule_enabled_txn,
-                stream_id,
-                event_stream_ordering,
-                user_id,
-                rule_id,
-                enabled,
-            )
-
-    def _set_push_rule_enabled_txn(
-        self, txn, stream_id, event_stream_ordering, user_id, rule_id, enabled
-    ):
-        new_id = self._push_rules_enable_id_gen.get_next()
-        self._simple_upsert_txn(
-            txn,
-            "push_rules_enable",
-            {"user_name": user_id, "rule_id": rule_id},
-            {"enabled": 1 if enabled else 0},
-            {"id": new_id},
-        )
-
-        self._insert_push_rules_update_txn(
-            txn,
-            stream_id,
-            event_stream_ordering,
-            user_id,
-            rule_id,
-            op="ENABLE" if enabled else "DISABLE",
-        )
-
-    @defer.inlineCallbacks
-    def set_push_rule_actions(self, user_id, rule_id, actions, is_default_rule):
-        actions_json = json.dumps(actions)
-
-        def set_push_rule_actions_txn(txn, stream_id, event_stream_ordering):
-            if is_default_rule:
-                # Add a dummy rule to the rules table with the user specified
-                # actions.
-                priority_class = -1
-                priority = 1
-                self._upsert_push_rule_txn(
-                    txn,
-                    stream_id,
-                    event_stream_ordering,
-                    user_id,
-                    rule_id,
-                    priority_class,
-                    priority,
-                    "[]",
-                    actions_json,
-                    update_stream=False,
-                )
-            else:
-                self._simple_update_one_txn(
-                    txn,
-                    "push_rules",
-                    {"user_name": user_id, "rule_id": rule_id},
-                    {"actions": actions_json},
-                )
-
-            self._insert_push_rules_update_txn(
-                txn,
-                stream_id,
-                event_stream_ordering,
-                user_id,
-                rule_id,
-                op="ACTIONS",
-                data={"actions": actions_json},
-            )
-
-        with self._push_rules_stream_id_gen.get_next() as ids:
-            stream_id, event_stream_ordering = ids
-            yield self.runInteraction(
-                "set_push_rule_actions",
-                set_push_rule_actions_txn,
-                stream_id,
-                event_stream_ordering,
-            )
-
-    def _insert_push_rules_update_txn(
-        self, txn, stream_id, event_stream_ordering, user_id, rule_id, op, data=None
-    ):
-        values = {
-            "stream_id": stream_id,
-            "event_stream_ordering": event_stream_ordering,
-            "user_id": user_id,
-            "rule_id": rule_id,
-            "op": op,
-        }
-        if data is not None:
-            values.update(data)
-
-        self._simple_insert_txn(txn, "push_rules_stream", values=values)
-
-        txn.call_after(self.get_push_rules_for_user.invalidate, (user_id,))
-        txn.call_after(self.get_push_rules_enabled_for_user.invalidate, (user_id,))
-        txn.call_after(
-            self.push_rules_stream_cache.entity_has_changed, user_id, stream_id
-        )
-
-    def get_all_push_rule_updates(self, last_id, current_id, limit):
-        """Get all the push rules changes that have happend on the server"""
-        if last_id == current_id:
-            return defer.succeed([])
-
-        def get_all_push_rule_updates_txn(txn):
-            sql = (
-                "SELECT stream_id, event_stream_ordering, user_id, rule_id,"
-                " op, priority_class, priority, conditions, actions"
-                " FROM push_rules_stream"
-                " WHERE ? < stream_id AND stream_id <= ?"
-                " ORDER BY stream_id ASC LIMIT ?"
-            )
-            txn.execute(sql, (last_id, current_id, limit))
-            return txn.fetchall()
-
-        return self.runInteraction(
-            "get_all_push_rule_updates", get_all_push_rule_updates_txn
-        )
-
-    def get_push_rules_stream_token(self):
-        """Get the position of the push rules stream.
-        Returns a pair of a stream id for the push_rules stream and the
-        room stream ordering it corresponds to."""
-        return self._push_rules_stream_id_gen.get_current_token()
-
-    def get_max_push_rules_stream_id(self):
-        return self.get_push_rules_stream_token()[0]
-
 
 class RuleNotFoundException(Exception):
     pass
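A minimal sketch, assuming a configured `store` object exposing the replication helpers removed above, of how a consumer might page through push rule changes using the stream token pair; the polling function and its limit are illustrative:

from twisted.internet import defer

@defer.inlineCallbacks
def poll_push_rule_updates(store, last_seen_id, limit=100):
    # get_push_rules_stream_token() returns (stream_id, event_stream_ordering);
    # only the stream id is needed to page through push_rules_stream.
    current_id, _ = store.get_push_rules_stream_token()
    if last_seen_id == current_id:
        return last_seen_id, []
    rows = yield store.get_all_push_rule_updates(last_seen_id, current_id, limit)
    return current_id, rows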
diff --git a/synapse/storage/relations.py b/synapse/storage/relations.py
index fcb5f2f23a..d471ec9860 100644
--- a/synapse/storage/relations.py
+++ b/synapse/storage/relations.py
@@ -17,11 +17,7 @@ import logging
 
 import attr
 
-from synapse.api.constants import RelationTypes
 from synapse.api.errors import SynapseError
-from synapse.storage._base import SQLBaseStore
-from synapse.storage.stream import generate_pagination_where_clause
-from synapse.util.caches.descriptors import cached, cachedInlineCallbacks
 
 logger = logging.getLogger(__name__)
 
@@ -113,358 +109,3 @@ class AggregationPaginationToken(object):
 
     def as_tuple(self):
         return attr.astuple(self)
-
-
-class RelationsWorkerStore(SQLBaseStore):
-    @cached(tree=True)
-    def get_relations_for_event(
-        self,
-        event_id,
-        relation_type=None,
-        event_type=None,
-        aggregation_key=None,
-        limit=5,
-        direction="b",
-        from_token=None,
-        to_token=None,
-    ):
-        """Get a list of relations for an event, ordered by topological ordering.
-
-        Args:
-            event_id (str): Fetch events that relate to this event ID.
-            relation_type (str|None): Only fetch events with this relation
-                type, if given.
-            event_type (str|None): Only fetch events with this event type, if
-                given.
-            aggregation_key (str|None): Only fetch events with this aggregation
-                key, if given.
-            limit (int): Only fetch the most recent `limit` events.
-            direction (str): Whether to fetch the most recent first (`"b"`) or
-                the oldest first (`"f"`).
-            from_token (RelationPaginationToken|None): Fetch rows from the given
-                token, or from the start if None.
-            to_token (RelationPaginationToken|None): Fetch rows up to the given
-                token, or up to the end if None.
-
-        Returns:
-            Deferred[PaginationChunk]: List of event IDs that match the
-            requested relations. The rows are of the form `{"event_id": "..."}`.
-        """
-
-        where_clause = ["relates_to_id = ?"]
-        where_args = [event_id]
-
-        if relation_type is not None:
-            where_clause.append("relation_type = ?")
-            where_args.append(relation_type)
-
-        if event_type is not None:
-            where_clause.append("type = ?")
-            where_args.append(event_type)
-
-        if aggregation_key:
-            where_clause.append("aggregation_key = ?")
-            where_args.append(aggregation_key)
-
-        pagination_clause = generate_pagination_where_clause(
-            direction=direction,
-            column_names=("topological_ordering", "stream_ordering"),
-            from_token=attr.astuple(from_token) if from_token else None,
-            to_token=attr.astuple(to_token) if to_token else None,
-            engine=self.database_engine,
-        )
-
-        if pagination_clause:
-            where_clause.append(pagination_clause)
-
-        if direction == "b":
-            order = "DESC"
-        else:
-            order = "ASC"
-
-        sql = """
-            SELECT event_id, topological_ordering, stream_ordering
-            FROM event_relations
-            INNER JOIN events USING (event_id)
-            WHERE %s
-            ORDER BY topological_ordering %s, stream_ordering %s
-            LIMIT ?
-        """ % (
-            " AND ".join(where_clause),
-            order,
-            order,
-        )
-
-        def _get_recent_references_for_event_txn(txn):
-            txn.execute(sql, where_args + [limit + 1])
-
-            last_topo_id = None
-            last_stream_id = None
-            events = []
-            for row in txn:
-                events.append({"event_id": row[0]})
-                last_topo_id = row[1]
-                last_stream_id = row[2]
-
-            next_batch = None
-            if len(events) > limit and last_topo_id and last_stream_id:
-                next_batch = RelationPaginationToken(last_topo_id, last_stream_id)
-
-            return PaginationChunk(
-                chunk=list(events[:limit]), next_batch=next_batch, prev_batch=from_token
-            )
-
-        return self.runInteraction(
-            "get_recent_references_for_event", _get_recent_references_for_event_txn
-        )
-
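A hedged usage sketch of the pagination contract documented above, assuming the same `store` object; the relation type is illustrative:

from twisted.internet import defer

@defer.inlineCallbacks
def fetch_all_relations(store, event_id, relation_type="m.annotation"):
    # Walk every page of relations, following next_batch until exhausted.
    token = None
    event_ids = []
    while True:
        chunk = yield store.get_relations_for_event(
            event_id, relation_type=relation_type, limit=50, from_token=token
        )
        event_ids.extend(row["event_id"] for row in chunk.chunk)
        token = chunk.next_batch
        if token is None:
            break
    return event_ids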
-    @cached(tree=True)
-    def get_aggregation_groups_for_event(
-        self,
-        event_id,
-        event_type=None,
-        limit=5,
-        direction="b",
-        from_token=None,
-        to_token=None,
-    ):
-        """Get a list of annotations on the event, grouped by event type and
-        aggregation key, sorted by count.
-
-        This is used e.g. to get which reactions have happened on an event,
-        and how many of each.
-
-        Args:
-            event_id (str): Fetch events that relate to this event ID.
-            event_type (str|None): Only fetch events with this event type, if
-                given.
-            limit (int): Only fetch the `limit` groups.
-            direction (str): Whether to fetch the highest count first (`"b"`) or
-                the lowest count first (`"f"`).
-            from_token (AggregationPaginationToken|None): Fetch rows from the
-                given token, or from the start if None.
-            to_token (AggregationPaginationToken|None): Fetch rows up to the
-                given token, or up to the end if None.
-
-
-        Returns:
-            Deferred[PaginationChunk]: List of groups of annotations that
-            match. Each row is a dict with `type`, `key` and `count` fields.
-        """
-
-        where_clause = ["relates_to_id = ?", "relation_type = ?"]
-        where_args = [event_id, RelationTypes.ANNOTATION]
-
-        if event_type:
-            where_clause.append("type = ?")
-            where_args.append(event_type)
-
-        having_clause = generate_pagination_where_clause(
-            direction=direction,
-            column_names=("COUNT(*)", "MAX(stream_ordering)"),
-            from_token=attr.astuple(from_token) if from_token else None,
-            to_token=attr.astuple(to_token) if to_token else None,
-            engine=self.database_engine,
-        )
-
-        if direction == "b":
-            order = "DESC"
-        else:
-            order = "ASC"
-
-        if having_clause:
-            having_clause = "HAVING " + having_clause
-        else:
-            having_clause = ""
-
-        sql = """
-            SELECT type, aggregation_key, COUNT(DISTINCT sender), MAX(stream_ordering)
-            FROM event_relations
-            INNER JOIN events USING (event_id)
-            WHERE {where_clause}
-            GROUP BY relation_type, type, aggregation_key
-            {having_clause}
-            ORDER BY COUNT(*) {order}, MAX(stream_ordering) {order}
-            LIMIT ?
-        """.format(
-            where_clause=" AND ".join(where_clause),
-            order=order,
-            having_clause=having_clause,
-        )
-
-        def _get_aggregation_groups_for_event_txn(txn):
-            txn.execute(sql, where_args + [limit + 1])
-
-            next_batch = None
-            events = []
-            for row in txn:
-                events.append({"type": row[0], "key": row[1], "count": row[2]})
-                next_batch = AggregationPaginationToken(row[2], row[3])
-
-            if len(events) <= limit:
-                next_batch = None
-
-            return PaginationChunk(
-                chunk=list(events[:limit]), next_batch=next_batch, prev_batch=from_token
-            )
-
-        return self.runInteraction(
-            "get_aggregation_groups_for_event", _get_aggregation_groups_for_event_txn
-        )
-
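A small sketch of consuming the grouped annotation counts above, e.g. to render reaction totals; again assuming a `store` object with the method shown:

from twisted.internet import defer

@defer.inlineCallbacks
def reaction_totals(store, event_id):
    # Each row is {"type": ..., "key": ..., "count": ...} as documented above.
    groups = yield store.get_aggregation_groups_for_event(event_id, limit=10)
    return {(g["type"], g["key"]): g["count"] for g in groups.chunk}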
-    @cachedInlineCallbacks()
-    def get_applicable_edit(self, event_id):
-        """Get the most recent edit (if any) that has happened for the given
-        event.
-
-        Correctly handles checking whether edits were allowed to happen.
-
-        Args:
-            event_id (str): The original event ID
-
-        Returns:
-            Deferred[EventBase|None]: Returns the most recent edit, if any.
-        """
-
-        # We only allow edits for `m.room.message` events that have the same sender
-        # and event type. We can't assert these things during regular event auth so
-        # we have to do the checks post hoc.
-
-        # Fetches latest edit that has the same type and sender as the
-        # original, and is an `m.room.message`.
-        sql = """
-            SELECT edit.event_id FROM events AS edit
-            INNER JOIN event_relations USING (event_id)
-            INNER JOIN events AS original ON
-                original.event_id = relates_to_id
-                AND edit.type = original.type
-                AND edit.sender = original.sender
-            WHERE
-                relates_to_id = ?
-                AND relation_type = ?
-                AND edit.type = 'm.room.message'
-            ORDER by edit.origin_server_ts DESC, edit.event_id DESC
-            LIMIT 1
-        """
-
-        def _get_applicable_edit_txn(txn):
-            txn.execute(sql, (event_id, RelationTypes.REPLACE))
-            row = txn.fetchone()
-            if row:
-                return row[0]
-
-        edit_id = yield self.runInteraction(
-            "get_applicable_edit", _get_applicable_edit_txn
-        )
-
-        if not edit_id:
-            return
-
-        edit_event = yield self.get_event(edit_id, allow_none=True)
-        return edit_event
-
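A hedged sketch of how a caller might use the edit lookup above to pick a display body; it assumes the replacement content lives under `m.new_content`, per the usual `m.replace` convention:

from twisted.internet import defer

@defer.inlineCallbacks
def get_display_body(store, original_event):
    # Prefer the body of the most recent valid edit, falling back to the
    # original event's body when no edit applies.
    edit = yield store.get_applicable_edit(original_event.event_id)
    if edit is None:
        return original_event.content.get("body")
    new_content = edit.content.get("m.new_content", {})
    return new_content.get("body", original_event.content.get("body"))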
-    def has_user_annotated_event(self, parent_id, event_type, aggregation_key, sender):
-        """Check if a user has already annotated an event with the same key
-        (e.g. already liked an event).
-
-        Args:
-            parent_id (str): The event being annotated
-            event_type (str): The event type of the annotation
-            aggregation_key (str): The aggregation key of the annotation
-            sender (str): The sender of the annotation
-
-        Returns:
-            Deferred[bool]
-        """
-
-        sql = """
-            SELECT 1 FROM event_relations
-            INNER JOIN events USING (event_id)
-            WHERE
-                relates_to_id = ?
-                AND relation_type = ?
-                AND type = ?
-                AND sender = ?
-                AND aggregation_key = ?
-            LIMIT 1;
-        """
-
-        def _get_if_user_has_annotated_event(txn):
-            txn.execute(
-                sql,
-                (
-                    parent_id,
-                    RelationTypes.ANNOTATION,
-                    event_type,
-                    sender,
-                    aggregation_key,
-                ),
-            )
-
-            return bool(txn.fetchone())
-
-        return self.runInteraction(
-            "get_if_user_has_annotated_event", _get_if_user_has_annotated_event
-        )
-
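A minimal sketch of the duplicate-annotation check the method above supports; the `m.reaction` event type and key are illustrative:

from twisted.internet import defer

@defer.inlineCallbacks
def can_react(store, parent_id, sender, key="👍"):
    # Refuse a second annotation with the same (type, key, sender), mirroring
    # the check a handler would perform before persisting a reaction.
    already = yield store.has_user_annotated_event(
        parent_id, "m.reaction", key, sender
    )
    return not already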
-
-class RelationsStore(RelationsWorkerStore):
-    def _handle_event_relations(self, txn, event):
-        """Handles inserting relation data during peristence of events
-
-        Args:
-            txn
-            event (EventBase)
-        """
-        relation = event.content.get("m.relates_to")
-        if not relation:
-            # No relations
-            return
-
-        rel_type = relation.get("rel_type")
-        if rel_type not in (
-            RelationTypes.ANNOTATION,
-            RelationTypes.REFERENCE,
-            RelationTypes.REPLACE,
-        ):
-            # Unknown relation type
-            return
-
-        parent_id = relation.get("event_id")
-        if not parent_id:
-            # Invalid relation
-            return
-
-        aggregation_key = relation.get("key")
-
-        self._simple_insert_txn(
-            txn,
-            table="event_relations",
-            values={
-                "event_id": event.event_id,
-                "relates_to_id": parent_id,
-                "relation_type": rel_type,
-                "aggregation_key": aggregation_key,
-            },
-        )
-
-        txn.call_after(self.get_relations_for_event.invalidate_many, (parent_id,))
-        txn.call_after(
-            self.get_aggregation_groups_for_event.invalidate_many, (parent_id,)
-        )
-
-        if rel_type == RelationTypes.REPLACE:
-            txn.call_after(self.get_applicable_edit.invalidate, (parent_id,))
-
-    def _handle_redaction(self, txn, redacted_event_id):
-        """Handles receiving a redaction and checking whether we need to remove
-        any redacted relations from the database.
-
-        Args:
-            txn
-            redacted_event_id (str): The event that was redacted.
-        """
-
-        self._simple_delete_txn(
-            txn, table="event_relations", keyvalues={"event_id": redacted_event_id}
-        )
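For reference, a hedged sketch of the event content shape that `_handle_event_relations` above inspects; the parent event ID and key are illustrative:

annotation_content = {
    "m.relates_to": {
        "rel_type": "m.annotation",           # RelationTypes.ANNOTATION
        "event_id": "$parent_event:example.org",
        "key": "👍",                           # stored as aggregation_key
    }
}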
diff --git a/synapse/storage/room.py b/synapse/storage/room.py
deleted file mode 100644
index 08e13f3a3b..0000000000
--- a/synapse/storage/room.py
+++ /dev/null
@@ -1,585 +0,0 @@
-# -*- coding: utf-8 -*-
-# Copyright 2014-2016 OpenMarket Ltd
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-#     http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-
-import collections
-import logging
-import re
-
-from canonicaljson import json
-
-from twisted.internet import defer
-
-from synapse.api.errors import StoreError
-from synapse.storage._base import SQLBaseStore
-from synapse.storage.search import SearchStore
-from synapse.util.caches.descriptors import cached, cachedInlineCallbacks
-
-logger = logging.getLogger(__name__)
-
-
-OpsLevel = collections.namedtuple(
-    "OpsLevel", ("ban_level", "kick_level", "redact_level")
-)
-
-RatelimitOverride = collections.namedtuple(
-    "RatelimitOverride", ("messages_per_second", "burst_count")
-)
-
-
-class RoomWorkerStore(SQLBaseStore):
-    def get_room(self, room_id):
-        """Retrieve a room.
-
-        Args:
-            room_id (str): The ID of the room to retrieve.
-        Returns:
-            A dict containing the room information, or None if the room is unknown.
-        """
-        return self._simple_select_one(
-            table="rooms",
-            keyvalues={"room_id": room_id},
-            retcols=("room_id", "is_public", "creator"),
-            desc="get_room",
-            allow_none=True,
-        )
-
-    def get_public_room_ids(self):
-        return self._simple_select_onecol(
-            table="rooms",
-            keyvalues={"is_public": True},
-            retcol="room_id",
-            desc="get_public_room_ids",
-        )
-
-    @cached(num_args=2, max_entries=100)
-    def get_public_room_ids_at_stream_id(self, stream_id, network_tuple):
-        """Get pulbic rooms for a particular list, or across all lists.
-
-        Args:
-            stream_id (int)
-            network_tuple (ThirdPartyInstanceID): The list to use (None, None)
-                means the main list, None means all lists.
-        """
-        return self.runInteraction(
-            "get_public_room_ids_at_stream_id",
-            self.get_public_room_ids_at_stream_id_txn,
-            stream_id,
-            network_tuple=network_tuple,
-        )
-
-    def get_public_room_ids_at_stream_id_txn(self, txn, stream_id, network_tuple):
-        return {
-            rm
-            for rm, vis in self.get_published_at_stream_id_txn(
-                txn, stream_id, network_tuple=network_tuple
-            ).items()
-            if vis
-        }
-
-    def get_published_at_stream_id_txn(self, txn, stream_id, network_tuple):
-        if network_tuple:
-            # We want to get from a particular list. No aggregation required.
-
-            sql = """
-                SELECT room_id, visibility FROM public_room_list_stream
-                INNER JOIN (
-                    SELECT room_id, max(stream_id) AS stream_id
-                    FROM public_room_list_stream
-                    WHERE stream_id <= ? %s
-                    GROUP BY room_id
-                ) grouped USING (room_id, stream_id)
-            """
-
-            if network_tuple.appservice_id is not None:
-                txn.execute(
-                    sql % ("AND appservice_id = ? AND network_id = ?",),
-                    (stream_id, network_tuple.appservice_id, network_tuple.network_id),
-                )
-            else:
-                txn.execute(sql % ("AND appservice_id IS NULL",), (stream_id,))
-            return dict(txn)
-        else:
-            # We want to get from all lists, so we need to aggregate the results
-
-            logger.info("Executing full list")
-
-            sql = """
-                SELECT room_id, visibility
-                FROM public_room_list_stream
-                INNER JOIN (
-                    SELECT
-                        room_id, max(stream_id) AS stream_id, appservice_id,
-                        network_id
-                    FROM public_room_list_stream
-                    WHERE stream_id <= ?
-                    GROUP BY room_id, appservice_id, network_id
-                ) grouped USING (room_id, stream_id)
-            """
-
-            txn.execute(sql, (stream_id,))
-
-            results = {}
-            # A room is visible if it's visible on any list.
-            for room_id, visibility in txn:
-                results[room_id] = bool(visibility) or results.get(room_id, False)
-
-            return results
-
-    def get_public_room_changes(self, prev_stream_id, new_stream_id, network_tuple):
-        def get_public_room_changes_txn(txn):
-            then_rooms = self.get_public_room_ids_at_stream_id_txn(
-                txn, prev_stream_id, network_tuple
-            )
-
-            now_rooms_dict = self.get_published_at_stream_id_txn(
-                txn, new_stream_id, network_tuple
-            )
-
-            now_rooms_visible = set(rm for rm, vis in now_rooms_dict.items() if vis)
-            now_rooms_not_visible = set(
-                rm for rm, vis in now_rooms_dict.items() if not vis
-            )
-
-            newly_visible = now_rooms_visible - then_rooms
-            newly_unpublished = now_rooms_not_visible & then_rooms
-
-            return newly_visible, newly_unpublished
-
-        return self.runInteraction(
-            "get_public_room_changes", get_public_room_changes_txn
-        )
-
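A worked example of the set arithmetic in `get_public_room_changes_txn` above, with made-up room IDs:

then_rooms = {"!a:hs", "!b:hs"}  # visible at prev_stream_id
now_rooms_dict = {"!a:hs": False, "!b:hs": True, "!c:hs": True}

now_visible = {rm for rm, vis in now_rooms_dict.items() if vis}          # {"!b:hs", "!c:hs"}
now_not_visible = {rm for rm, vis in now_rooms_dict.items() if not vis}  # {"!a:hs"}

newly_visible = now_visible - then_rooms          # {"!c:hs"}: published since then
newly_unpublished = now_not_visible & then_rooms  # {"!a:hs"}: unpublished since then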
-    @cached(max_entries=10000)
-    def is_room_blocked(self, room_id):
-        return self._simple_select_one_onecol(
-            table="blocked_rooms",
-            keyvalues={"room_id": room_id},
-            retcol="1",
-            allow_none=True,
-            desc="is_room_blocked",
-        )
-
-    @cachedInlineCallbacks(max_entries=10000)
-    def get_ratelimit_for_user(self, user_id):
-        """Check if there are any overrides for ratelimiting for the given
-        user
-
-        Args:
-            user_id (str)
-
-        Returns:
-            RatelimitOverride if there is an override, else None. If the contents
-            of RatelimitOverride are None or 0 then ratelimiting has been
-            disabled for that user entirely.
-        """
-        row = yield self._simple_select_one(
-            table="ratelimit_override",
-            keyvalues={"user_id": user_id},
-            retcols=("messages_per_second", "burst_count"),
-            allow_none=True,
-            desc="get_ratelimit_for_user",
-        )
-
-        if row:
-            return RatelimitOverride(
-                messages_per_second=row["messages_per_second"],
-                burst_count=row["burst_count"],
-            )
-        else:
-            return None
-
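A minimal sketch of how a caller might interpret the override described above; the default rate and burst values are illustrative:

from twisted.internet import defer

@defer.inlineCallbacks
def effective_ratelimit(store, user_id, default_rate=0.2, default_burst=10):
    override = yield store.get_ratelimit_for_user(user_id)
    if override is None:
        return default_rate, default_burst
    if not override.messages_per_second and not override.burst_count:
        # None/0 means ratelimiting is disabled entirely for this user.
        return None, None
    return override.messages_per_second, override.burst_count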
-
-class RoomStore(RoomWorkerStore, SearchStore):
-    @defer.inlineCallbacks
-    def store_room(self, room_id, room_creator_user_id, is_public):
-        """Stores a room.
-
-        Args:
-            room_id (str): The desired room ID, can be None.
-            room_creator_user_id (str): The user ID of the room creator.
-            is_public (bool): True to indicate that this room should appear in
-            public room lists.
-        Raises:
-            StoreError if the room could not be stored.
-        """
-        try:
-
-            def store_room_txn(txn, next_id):
-                self._simple_insert_txn(
-                    txn,
-                    "rooms",
-                    {
-                        "room_id": room_id,
-                        "creator": room_creator_user_id,
-                        "is_public": is_public,
-                    },
-                )
-                if is_public:
-                    self._simple_insert_txn(
-                        txn,
-                        table="public_room_list_stream",
-                        values={
-                            "stream_id": next_id,
-                            "room_id": room_id,
-                            "visibility": is_public,
-                        },
-                    )
-
-            with self._public_room_id_gen.get_next() as next_id:
-                yield self.runInteraction("store_room_txn", store_room_txn, next_id)
-        except Exception as e:
-            logger.error("store_room with room_id=%s failed: %s", room_id, e)
-            raise StoreError(500, "Problem creating room.")
-
-    @defer.inlineCallbacks
-    def set_room_is_public(self, room_id, is_public):
-        def set_room_is_public_txn(txn, next_id):
-            self._simple_update_one_txn(
-                txn,
-                table="rooms",
-                keyvalues={"room_id": room_id},
-                updatevalues={"is_public": is_public},
-            )
-
-            entries = self._simple_select_list_txn(
-                txn,
-                table="public_room_list_stream",
-                keyvalues={
-                    "room_id": room_id,
-                    "appservice_id": None,
-                    "network_id": None,
-                },
-                retcols=("stream_id", "visibility"),
-            )
-
-            entries.sort(key=lambda r: r["stream_id"])
-
-            add_to_stream = True
-            if entries:
-                add_to_stream = bool(entries[-1]["visibility"]) != is_public
-
-            if add_to_stream:
-                self._simple_insert_txn(
-                    txn,
-                    table="public_room_list_stream",
-                    values={
-                        "stream_id": next_id,
-                        "room_id": room_id,
-                        "visibility": is_public,
-                        "appservice_id": None,
-                        "network_id": None,
-                    },
-                )
-
-        with self._public_room_id_gen.get_next() as next_id:
-            yield self.runInteraction(
-                "set_room_is_public", set_room_is_public_txn, next_id
-            )
-        self.hs.get_notifier().on_new_replication_data()
-
-    @defer.inlineCallbacks
-    def set_room_is_public_appservice(
-        self, room_id, appservice_id, network_id, is_public
-    ):
-        """Edit the appservice/network specific public room list.
-
-        Each appservice can have a number of published room lists associated
-        with it, keyed off an appservice-defined `network_id`, which
-        essentially represents a single instance of a bridge to a third-party
-        network.
-
-        Args:
-            room_id (str)
-            appservice_id (str)
-            network_id (str)
-            is_public (bool): Whether to publish or unpublish the room from the
-                list.
-        """
-
-        def set_room_is_public_appservice_txn(txn, next_id):
-            if is_public:
-                try:
-                    self._simple_insert_txn(
-                        txn,
-                        table="appservice_room_list",
-                        values={
-                            "appservice_id": appservice_id,
-                            "network_id": network_id,
-                            "room_id": room_id,
-                        },
-                    )
-                except self.database_engine.module.IntegrityError:
-                    # We've already inserted, nothing to do.
-                    return
-            else:
-                self._simple_delete_txn(
-                    txn,
-                    table="appservice_room_list",
-                    keyvalues={
-                        "appservice_id": appservice_id,
-                        "network_id": network_id,
-                        "room_id": room_id,
-                    },
-                )
-
-            entries = self._simple_select_list_txn(
-                txn,
-                table="public_room_list_stream",
-                keyvalues={
-                    "room_id": room_id,
-                    "appservice_id": appservice_id,
-                    "network_id": network_id,
-                },
-                retcols=("stream_id", "visibility"),
-            )
-
-            entries.sort(key=lambda r: r["stream_id"])
-
-            add_to_stream = True
-            if entries:
-                add_to_stream = bool(entries[-1]["visibility"]) != is_public
-
-            if add_to_stream:
-                self._simple_insert_txn(
-                    txn,
-                    table="public_room_list_stream",
-                    values={
-                        "stream_id": next_id,
-                        "room_id": room_id,
-                        "visibility": is_public,
-                        "appservice_id": appservice_id,
-                        "network_id": network_id,
-                    },
-                )
-
-        with self._public_room_id_gen.get_next() as next_id:
-            yield self.runInteraction(
-                "set_room_is_public_appservice",
-                set_room_is_public_appservice_txn,
-                next_id,
-            )
-        self.hs.get_notifier().on_new_replication_data()
-
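A hedged usage sketch of the appservice room list API above; the room, appservice and network identifiers are all illustrative:

from twisted.internet import defer

@defer.inlineCallbacks
def publish_bridged_room(store):
    # Publish a bridged room on the list owned by one bridge instance.
    yield store.set_room_is_public_appservice(
        room_id="!bridged:example.org",
        appservice_id="irc-bridge",
        network_id="libera",
        is_public=True,
    )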
-    def get_room_count(self):
-        """Retrieve a list of all rooms
-        """
-
-        def f(txn):
-            sql = "SELECT count(*)  FROM rooms"
-            txn.execute(sql)
-            row = txn.fetchone()
-            return row[0] or 0
-
-        return self.runInteraction("get_rooms", f)
-
-    def _store_room_topic_txn(self, txn, event):
-        if hasattr(event, "content") and "topic" in event.content:
-            self.store_event_search_txn(
-                txn, event, "content.topic", event.content["topic"]
-            )
-
-    def _store_room_name_txn(self, txn, event):
-        if hasattr(event, "content") and "name" in event.content:
-            self.store_event_search_txn(
-                txn, event, "content.name", event.content["name"]
-            )
-
-    def _store_room_message_txn(self, txn, event):
-        if hasattr(event, "content") and "body" in event.content:
-            self.store_event_search_txn(
-                txn, event, "content.body", event.content["body"]
-            )
-
-    def add_event_report(
-        self, room_id, event_id, user_id, reason, content, received_ts
-    ):
-        next_id = self._event_reports_id_gen.get_next()
-        return self._simple_insert(
-            table="event_reports",
-            values={
-                "id": next_id,
-                "received_ts": received_ts,
-                "room_id": room_id,
-                "event_id": event_id,
-                "user_id": user_id,
-                "reason": reason,
-                "content": json.dumps(content),
-            },
-            desc="add_event_report",
-        )
-
-    def get_current_public_room_stream_id(self):
-        return self._public_room_id_gen.get_current_token()
-
-    def get_all_new_public_rooms(self, prev_id, current_id, limit):
-        def get_all_new_public_rooms(txn):
-            sql = """
-                SELECT stream_id, room_id, visibility, appservice_id, network_id
-                FROM public_room_list_stream
-                WHERE stream_id > ? AND stream_id <= ?
-                ORDER BY stream_id ASC
-                LIMIT ?
-            """
-
-            txn.execute(sql, (prev_id, current_id, limit))
-            return txn.fetchall()
-
-        if prev_id == current_id:
-            return defer.succeed([])
-
-        return self.runInteraction("get_all_new_public_rooms", get_all_new_public_rooms)
-
-    @defer.inlineCallbacks
-    def block_room(self, room_id, user_id):
-        """Marks the room as blocked. Can be called multiple times.
-
-        Args:
-            room_id (str): Room to block
-            user_id (str): Who blocked it
-
-        Returns:
-            Deferred
-        """
-        yield self._simple_upsert(
-            table="blocked_rooms",
-            keyvalues={"room_id": room_id},
-            values={},
-            insertion_values={"user_id": user_id},
-            desc="block_room",
-        )
-        yield self.runInteraction(
-            "block_room_invalidation",
-            self._invalidate_cache_and_stream,
-            self.is_room_blocked,
-            (room_id,),
-        )
-
-    def get_media_mxcs_in_room(self, room_id):
-        """Retrieves all the local and remote media MXC URIs in a given room
-
-        Args:
-            room_id (str)
-
-        Returns:
-            A pair of lists of MXC URI strings: the local media, then the
-            remote media.
-        """
-
-        def _get_media_mxcs_in_room_txn(txn):
-            local_mxcs, remote_mxcs = self._get_media_mxcs_in_room_txn(txn, room_id)
-            local_media_mxcs = []
-            remote_media_mxcs = []
-
-            # Convert the IDs to MXC URIs
-            for media_id in local_mxcs:
-                local_media_mxcs.append("mxc://%s/%s" % (self.hs.hostname, media_id))
-            for hostname, media_id in remote_mxcs:
-                remote_media_mxcs.append("mxc://%s/%s" % (hostname, media_id))
-
-            return local_media_mxcs, remote_media_mxcs
-
-        return self.runInteraction("get_media_ids_in_room", _get_media_mxcs_in_room_txn)
-
-    def quarantine_media_ids_in_room(self, room_id, quarantined_by):
-        """For a room loops through all events with media and quarantines
-        the associated media
-        """
-
-        def _quarantine_media_in_room_txn(txn):
-            local_mxcs, remote_mxcs = self._get_media_mxcs_in_room_txn(txn, room_id)
-            total_media_quarantined = 0
-
-            # Now update all the tables to set the quarantined_by flag
-
-            txn.executemany(
-                """
-                UPDATE local_media_repository
-                SET quarantined_by = ?
-                WHERE media_id = ?
-            """,
-                ((quarantined_by, media_id) for media_id in local_mxcs),
-            )
-
-            txn.executemany(
-                """
-                    UPDATE remote_media_cache
-                    SET quarantined_by = ?
-                    WHERE media_origin = ? AND media_id = ?
-                """,
-                (
-                    (quarantined_by, origin, media_id)
-                    for origin, media_id in remote_mxcs
-                ),
-            )
-
-            total_media_quarantined += len(local_mxcs)
-            total_media_quarantined += len(remote_mxcs)
-
-            return total_media_quarantined
-
-        return self.runInteraction(
-            "quarantine_media_in_room", _quarantine_media_in_room_txn
-        )
-
-    def _get_media_mxcs_in_room_txn(self, txn, room_id):
-        """Retrieves all the local and remote media MXC URIs in a given room
-
-        Args:
-            txn (cursor)
-            room_id (str)
-
-        Returns:
-            A pair of lists: local media IDs, and remote media as
-            (hostname, media ID) tuples.
-        """
-        mxc_re = re.compile("^mxc://([^/]+)/([^/#?]+)")
-
-        next_token = self.get_current_events_token() + 1
-        local_media_mxcs = []
-        remote_media_mxcs = []
-
-        while next_token:
-            sql = """
-                SELECT stream_ordering, json FROM events
-                JOIN event_json USING (room_id, event_id)
-                WHERE room_id = ?
-                    AND stream_ordering < ?
-                    AND contains_url = ? AND outlier = ?
-                ORDER BY stream_ordering DESC
-                LIMIT ?
-            """
-            txn.execute(sql, (room_id, next_token, True, False, 100))
-
-            next_token = None
-            for stream_ordering, content_json in txn:
-                next_token = stream_ordering
-                event_json = json.loads(content_json)
-                content = event_json["content"]
-                content_url = content.get("url")
-                thumbnail_url = content.get("info", {}).get("thumbnail_url")
-
-                for url in (content_url, thumbnail_url):
-                    if not url:
-                        continue
-                    matches = mxc_re.match(url)
-                    if matches:
-                        hostname = matches.group(1)
-                        media_id = matches.group(2)
-                        if hostname == self.hs.hostname:
-                            local_media_mxcs.append(media_id)
-                        else:
-                            remote_media_mxcs.append((hostname, media_id))
-
-        return local_media_mxcs, remote_media_mxcs
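A small sketch of the `mxc://` parsing that `_get_media_mxcs_in_room_txn` performs with `mxc_re`; the URI and hostname are illustrative:

import re

mxc_re = re.compile("^mxc://([^/]+)/([^/#?]+)")

def split_mxc(url, local_hostname):
    # Return ("local" | "remote", hostname, media_id), or None for non-mxc URLs.
    match = mxc_re.match(url)
    if not match:
        return None
    hostname, media_id = match.group(1), match.group(2)
    kind = "local" if hostname == local_hostname else "remote"
    return kind, hostname, media_id

# split_mxc("mxc://example.org/abc123", "example.org") -> ("local", "example.org", "abc123")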
diff --git a/synapse/storage/roommember.py b/synapse/storage/roommember.py
index 4df8ebdacd..8c4a83a840 100644
--- a/synapse/storage/roommember.py
+++ b/synapse/storage/roommember.py
@@ -17,24 +17,6 @@
 import logging
 from collections import namedtuple
 
-from six import iteritems, itervalues
-
-from canonicaljson import json
-
-from twisted.internet import defer
-
-from synapse.api.constants import EventTypes, Membership
-from synapse.metrics import LaterGauge
-from synapse.metrics.background_process_metrics import run_as_background_process
-from synapse.storage._base import LoggingTransaction
-from synapse.storage.engines import Sqlite3Engine
-from synapse.storage.events_worker import EventsWorkerStore
-from synapse.types import get_domain_from_id
-from synapse.util.async_helpers import Linearizer
-from synapse.util.caches import intern_string
-from synapse.util.caches.descriptors import cached, cachedInlineCallbacks
-from synapse.util.stringutils import to_ascii
-
 logger = logging.getLogger(__name__)
 
 
@@ -55,1057 +37,3 @@ ProfileInfo = namedtuple("ProfileInfo", ("avatar_url", "display_name"))
 # a given membership type, suitable for use in calculating heroes for a room.
 # "count" points to the total numberr of users of a given membership type.
 MemberSummary = namedtuple("MemberSummary", ("members", "count"))
-
-_MEMBERSHIP_PROFILE_UPDATE_NAME = "room_membership_profile_update"
-_CURRENT_STATE_MEMBERSHIP_UPDATE_NAME = "current_state_events_membership"
-
-
-class RoomMemberWorkerStore(EventsWorkerStore):
-    def __init__(self, db_conn, hs):
-        super(RoomMemberWorkerStore, self).__init__(db_conn, hs)
-
-        # Is the current_state_events.membership up to date? Or is the
-        # background update still running?
-        self._current_state_events_membership_up_to_date = False
-
-        txn = LoggingTransaction(
-            db_conn.cursor(),
-            name="_check_safe_current_state_events_membership_updated",
-            database_engine=self.database_engine,
-        )
-        self._check_safe_current_state_events_membership_updated_txn(txn)
-        txn.close()
-
-        if self.hs.config.metrics_flags.known_servers:
-            self._known_servers_count = 1
-            self.hs.get_clock().looping_call(
-                run_as_background_process,
-                60 * 1000,
-                "_count_known_servers",
-                self._count_known_servers,
-            )
-            self.hs.get_clock().call_later(
-                1000,
-                run_as_background_process,
-                "_count_known_servers",
-                self._count_known_servers,
-            )
-            LaterGauge(
-                "synapse_federation_known_servers",
-                "",
-                [],
-                lambda: self._known_servers_count,
-            )
-
-    @defer.inlineCallbacks
-    def _count_known_servers(self):
-        """
-        Count the servers that this server knows about.
-
-        The statistic is stored on the class for the
-        `synapse_federation_known_servers` LaterGauge to collect.
-        """
-
-        def _transact(txn):
-            if isinstance(self.database_engine, Sqlite3Engine):
-                query = """
-                    SELECT COUNT(DISTINCT substr(out.user_id, pos+1))
-                    FROM (
-                        SELECT rm.user_id as user_id, instr(rm.user_id, ':')
-                            AS pos FROM room_memberships as rm
-                        INNER JOIN current_state_events as c ON rm.event_id = c.event_id
-                        WHERE c.type = 'm.room.member'
-                    ) as out
-                """
-            else:
-                query = """
-                    SELECT COUNT(DISTINCT split_part(state_key, ':', 2))
-                    FROM current_state_events
-                    WHERE type = 'm.room.member' AND membership = 'join';
-                """
-            txn.execute(query)
-            return list(txn)[0][0]
-
-        count = yield self.runInteraction("get_known_servers", _transact)
-
-        # We always know about ourselves, even if we have nothing in
-        # room_memberships (for example, the server is new).
-        self._known_servers_count = max([count, 1])
-        return self._known_servers_count
-
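Both SQL variants above reduce to counting distinct domains among joined members' user IDs; a Python sketch of that reduction:

def count_known_servers(joined_user_ids):
    # A Matrix user ID is "@localpart:domain"; everything after the first
    # colon is the server name, mirroring instr()/split_part() in the SQL.
    domains = {uid.split(":", 1)[1] for uid in joined_user_ids if ":" in uid}
    # We always know about ourselves, even with no memberships recorded yet.
    return max(len(domains), 1)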
-    def _check_safe_current_state_events_membership_updated_txn(self, txn):
-        """Checks if it is safe to assume the new current_state_events
-        membership column is up to date
-        """
-
-        pending_update = self._simple_select_one_txn(
-            txn,
-            table="background_updates",
-            keyvalues={"update_name": _CURRENT_STATE_MEMBERSHIP_UPDATE_NAME},
-            retcols=["update_name"],
-            allow_none=True,
-        )
-
-        self._current_state_events_membership_up_to_date = not pending_update
-
-        # If the update is still running, reschedule to run.
-        if pending_update:
-            self._clock.call_later(
-                15.0,
-                run_as_background_process,
-                "_check_safe_current_state_events_membership_updated",
-                self.runInteraction,
-                "_check_safe_current_state_events_membership_updated",
-                self._check_safe_current_state_events_membership_updated_txn,
-            )
-
-    @cachedInlineCallbacks(max_entries=100000, iterable=True, cache_context=True)
-    def get_hosts_in_room(self, room_id, cache_context):
-        """Returns the set of all hosts currently in the room
-        """
-        user_ids = yield self.get_users_in_room(
-            room_id, on_invalidate=cache_context.invalidate
-        )
-        hosts = frozenset(get_domain_from_id(user_id) for user_id in user_ids)
-        return hosts
-
-    @cached(max_entries=100000, iterable=True)
-    def get_users_in_room(self, room_id):
-        return self.runInteraction(
-            "get_users_in_room", self.get_users_in_room_txn, room_id
-        )
-
-    def get_users_in_room_txn(self, txn, room_id):
-        # If we can assume current_state_events.membership is up to date
-        # then we can avoid a join, which is a Very Good Thing given how
-        # frequently this function gets called.
-        if self._current_state_events_membership_up_to_date:
-            sql = """
-                SELECT state_key FROM current_state_events
-                WHERE type = 'm.room.member' AND room_id = ? AND membership = ?
-            """
-        else:
-            sql = """
-                SELECT state_key FROM room_memberships as m
-                INNER JOIN current_state_events as c
-                ON m.event_id = c.event_id
-                AND m.room_id = c.room_id
-                AND m.user_id = c.state_key
-                WHERE c.type = 'm.room.member' AND c.room_id = ? AND m.membership = ?
-            """
-
-        txn.execute(sql, (room_id, Membership.JOIN))
-        return [to_ascii(r[0]) for r in txn]
-
-    @cached(max_entries=100000)
-    def get_room_summary(self, room_id):
-        """ Get the details of a room roughly suitable for use by the room
-        summary extension to /sync. Useful when lazy loading room members.
-        Args:
-            room_id (str): The room ID to query
-        Returns:
-            Deferred[dict[str, MemberSummary]]:
-                dict of membership states, pointing to a MemberSummary named tuple.
-        """
-
-        def _get_room_summary_txn(txn):
-            # first get counts.
-            # We do this all in one transaction to keep the cache small.
-            # FIXME: get rid of this when we have room_stats
-
-            # If we can assume current_state_events.membership is up to date
-            # then we can avoid a join, which is a Very Good Thing given how
-            # frequently this function gets called.
-            if self._current_state_events_membership_up_to_date:
-                # Note, rejected events will have a null membership field, so
-                # we manually filter them out.
-                sql = """
-                    SELECT count(*), membership FROM current_state_events
-                    WHERE type = 'm.room.member' AND room_id = ?
-                        AND membership IS NOT NULL
-                    GROUP BY membership
-                """
-            else:
-                sql = """
-                    SELECT count(*), m.membership FROM room_memberships as m
-                    INNER JOIN current_state_events as c
-                    ON m.event_id = c.event_id
-                    AND m.room_id = c.room_id
-                    AND m.user_id = c.state_key
-                    WHERE c.type = 'm.room.member' AND c.room_id = ?
-                    GROUP BY m.membership
-                """
-
-            txn.execute(sql, (room_id,))
-            res = {}
-            for count, membership in txn:
-                summary = res.setdefault(to_ascii(membership), MemberSummary([], count))
-
-            # we order by membership and then fairly arbitrarily by event_id so
-            # heroes are consistent
-            if self._current_state_events_membership_up_to_date:
-                # Note, rejected events will have a null membership field, so
-                # we manually filter them out.
-                sql = """
-                    SELECT state_key, membership, event_id
-                    FROM current_state_events
-                    WHERE type = 'm.room.member' AND room_id = ?
-                        AND membership IS NOT NULL
-                    ORDER BY
-                        CASE membership WHEN ? THEN 1 WHEN ? THEN 2 ELSE 3 END ASC,
-                        event_id ASC
-                    LIMIT ?
-                """
-            else:
-                sql = """
-                    SELECT c.state_key, m.membership, c.event_id
-                    FROM room_memberships as m
-                    INNER JOIN current_state_events as c USING (room_id, event_id)
-                    WHERE c.type = 'm.room.member' AND c.room_id = ?
-                    ORDER BY
-                        CASE m.membership WHEN ? THEN 1 WHEN ? THEN 2 ELSE 3 END ASC,
-                        c.event_id ASC
-                    LIMIT ?
-                """
-
-            # 6 is 5 (number of heroes) plus 1, in case one of them is the calling user.
-            txn.execute(sql, (room_id, Membership.JOIN, Membership.INVITE, 6))
-            for user_id, membership, event_id in txn:
-                summary = res[to_ascii(membership)]
-                # we will always have a summary for this membership type at this
-                # point given the summary currently contains the counts.
-                members = summary.members
-                members.append((to_ascii(user_id), to_ascii(event_id)))
-
-            return res
-
-        return self.runInteraction("get_room_summary", _get_room_summary_txn)
-
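A hedged sketch of consuming the summary above, e.g. to pick heroes for /sync, preferring joined over invited members and excluding the requesting user; `MemberSummary` is the namedtuple kept earlier in this file and `store` is assumed as before:

from twisted.internet import defer

@defer.inlineCallbacks
def pick_heroes(store, room_id, requester):
    summary = yield store.get_room_summary(room_id)
    heroes = []
    for membership in ("join", "invite"):  # mirrors the SQL's CASE ordering
        members = summary.get(membership, MemberSummary([], 0)).members
        heroes.extend(uid for uid, _event_id in members if uid != requester)
    return heroes[:5]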
-    def _get_user_counts_in_room_txn(self, txn, room_id):
-        """
-        Get the user counts in a room, grouped by membership.
-
-        Args:
-            txn
-            room_id (str)
-
-        Returns:
-            dict[str, int]: Mapping from membership to the number of users
-            with that membership in the room.
-        """
-        sql = """
-        SELECT m.membership, count(*) FROM room_memberships as m
-            INNER JOIN current_state_events as c USING(event_id)
-            WHERE c.type = 'm.room.member' AND c.room_id = ?
-            GROUP BY m.membership
-        """
-
-        txn.execute(sql, (room_id,))
-        return {row[0]: row[1] for row in txn}
-
-    @cached()
-    def get_invited_rooms_for_user(self, user_id):
-        """ Get all the rooms the user is invited to
-        Args:
-            user_id (str): The user ID.
-        Returns:
-            A deferred list of RoomsForUser.
-        """
-
-        return self.get_rooms_for_user_where_membership_is(user_id, [Membership.INVITE])
-
-    @defer.inlineCallbacks
-    def get_invite_for_user_in_room(self, user_id, room_id):
-        """Gets the invite for the given user and room
-
-        Args:
-            user_id (str)
-            room_id (str)
-
-        Returns:
-            Deferred: Resolves to either a RoomsForUser or None if no invite was
-                found.
-        """
-        invites = yield self.get_invited_rooms_for_user(user_id)
-        for invite in invites:
-            if invite.room_id == room_id:
-                return invite
-        return None
-
-    @defer.inlineCallbacks
-    def get_rooms_for_user_where_membership_is(self, user_id, membership_list):
-        """ Get all the rooms for this user where the membership for this user
-        matches one in the membership list.
-
-        Filters out forgotten rooms.
-
-        Args:
-            user_id (str): The user ID.
-            membership_list (list): A list of synapse.api.constants.Membership
-            values which the user must be in.
-
-        Returns:
-            Deferred[list[RoomsForUser]]
-        """
-        if not membership_list:
-            return defer.succeed(None)
-
-        rooms = yield self.runInteraction(
-            "get_rooms_for_user_where_membership_is",
-            self._get_rooms_for_user_where_membership_is_txn,
-            user_id,
-            membership_list,
-        )
-
-        # Now we filter out forgotten rooms
-        forgotten_rooms = yield self.get_forgotten_rooms_for_user(user_id)
-        return [room for room in rooms if room.room_id not in forgotten_rooms]
-
-    def _get_rooms_for_user_where_membership_is_txn(
-        self, txn, user_id, membership_list
-    ):
-
-        do_invite = Membership.INVITE in membership_list
-        membership_list = [m for m in membership_list if m != Membership.INVITE]
-
-        results = []
-        if membership_list:
-            if self._current_state_events_membership_up_to_date:
-                sql = """
-                    SELECT room_id, e.sender, c.membership, event_id, e.stream_ordering
-                    FROM current_state_events AS c
-                    INNER JOIN events AS e USING (room_id, event_id)
-                    WHERE
-                        c.type = 'm.room.member'
-                        AND state_key = ?
-                        AND c.membership IN (%s)
-                """ % (
-                    ",".join("?" * len(membership_list))
-                )
-            else:
-                sql = """
-                    SELECT room_id, e.sender, m.membership, event_id, e.stream_ordering
-                    FROM current_state_events AS c
-                    INNER JOIN room_memberships AS m USING (room_id, event_id)
-                    INNER JOIN events AS e USING (room_id, event_id)
-                    WHERE
-                        c.type = 'm.room.member'
-                        AND state_key = ?
-                        AND m.membership IN (%s)
-                """ % (
-                    ",".join("?" * len(membership_list))
-                )
-
-            txn.execute(sql, (user_id, *membership_list))
-            results = [RoomsForUser(**r) for r in self.cursor_to_dict(txn)]
-
-        if do_invite:
-            sql = (
-                "SELECT i.room_id, inviter, i.event_id, e.stream_ordering"
-                " FROM local_invites as i"
-                " INNER JOIN events as e USING (event_id)"
-                " WHERE invitee = ? AND locally_rejected is NULL"
-                " AND replaced_by is NULL"
-            )
-
-            txn.execute(sql, (user_id,))
-            results.extend(
-                RoomsForUser(
-                    room_id=r["room_id"],
-                    sender=r["inviter"],
-                    event_id=r["event_id"],
-                    stream_ordering=r["stream_ordering"],
-                    membership=Membership.INVITE,
-                )
-                for r in self.cursor_to_dict(txn)
-            )
-
-        return results
-
-    @cachedInlineCallbacks(max_entries=500000, iterable=True)
-    def get_rooms_for_user_with_stream_ordering(self, user_id):
-        """Returns a set of room_ids the user is currently joined to
-
-        Args:
-            user_id (str)
-
-        Returns:
-            Deferred[frozenset[GetRoomsForUserWithStreamOrdering]]: Returns
-            the rooms the user is in currently, along with the stream ordering
-            of the most recent join for that user and room.
-        """
-        rooms = yield self.get_rooms_for_user_where_membership_is(
-            user_id, membership_list=[Membership.JOIN]
-        )
-        return frozenset(
-            GetRoomsForUserWithStreamOrdering(r.room_id, r.stream_ordering)
-            for r in rooms
-        )
-
-    @defer.inlineCallbacks
-    def get_rooms_for_user(self, user_id, on_invalidate=None):
-        """Returns a set of room_ids the user is currently joined to
-        """
-        rooms = yield self.get_rooms_for_user_with_stream_ordering(
-            user_id, on_invalidate=on_invalidate
-        )
-        return frozenset(r.room_id for r in rooms)
-
-    @cachedInlineCallbacks(max_entries=500000, cache_context=True, iterable=True)
-    def get_users_who_share_room_with_user(self, user_id, cache_context):
-        """Returns the set of users who share a room with `user_id`
-        """
-        room_ids = yield self.get_rooms_for_user(
-            user_id, on_invalidate=cache_context.invalidate
-        )
-
-        user_who_share_room = set()
-        for room_id in room_ids:
-            user_ids = yield self.get_users_in_room(
-                room_id, on_invalidate=cache_context.invalidate
-            )
-            user_who_share_room.update(user_ids)
-
-        return user_who_share_room
-
-    @defer.inlineCallbacks
-    def get_joined_users_from_context(self, event, context):
-        state_group = context.state_group
-        if not state_group:
-            # If state_group is None it means it has yet to be assigned a
-            # state group, i.e. we need to make sure that calls with a state_group
-            # of None don't hit previous cached calls with a None state_group.
-            # To do this we set the state_group to a new object as object() != object()
-            state_group = object()
-
-        current_state_ids = yield context.get_current_state_ids(self)
-        result = yield self._get_joined_users_from_context(
-            event.room_id, state_group, current_state_ids, event=event, context=context
-        )
-        return result
-
-    def get_joined_users_from_state(self, room_id, state_entry):
-        state_group = state_entry.state_group
-        if not state_group:
-            # If state_group is None it means it has yet to be assigned a
-            # state group, i.e. we need to make sure that calls with a state_group
-            # of None don't hit previous cached calls with a None state_group.
-            # To do this we set the state_group to a new object as object() != object()
-            state_group = object()
-
-        return self._get_joined_users_from_context(
-            room_id, state_group, state_entry.state, context=state_entry
-        )
-
-    @cachedInlineCallbacks(
-        num_args=2, cache_context=True, iterable=True, max_entries=100000
-    )
-    def _get_joined_users_from_context(
-        self,
-        room_id,
-        state_group,
-        current_state_ids,
-        cache_context,
-        event=None,
-        context=None,
-    ):
-        # We don't use `state_group`, it's there so that we can cache based
-        # on it. However, it's important that it's never None, since two current_states
-        # with a state_group of None are likely to be different.
-        # See bulk_get_push_rules_for_room for how we work around this.
-        assert state_group is not None
-
-        users_in_room = {}
-        member_event_ids = [
-            e_id
-            for key, e_id in iteritems(current_state_ids)
-            if key[0] == EventTypes.Member
-        ]
-
-        if context is not None:
-            # If we have a context with a delta from a previous state group,
-            # check if we also have the result from the previous group in cache.
-            # If we do then we can reuse that result and simply update it with
-            # any membership changes in `delta_ids`
-            if context.prev_group and context.delta_ids:
-                prev_res = self._get_joined_users_from_context.cache.get(
-                    (room_id, context.prev_group), None
-                )
-                if prev_res and isinstance(prev_res, dict):
-                    users_in_room = dict(prev_res)
-                    member_event_ids = [
-                        e_id
-                        for key, e_id in iteritems(context.delta_ids)
-                        if key[0] == EventTypes.Member
-                    ]
-                    for etype, state_key in context.delta_ids:
-                        users_in_room.pop(state_key, None)
-
-        # We check if we have any of the member event ids in the event cache
-        # before we ask the DB
-
-        # We don't update the event cache hit ratio as it completely throws off
-        # the hit ratio counts. After all, we don't populate the cache if we
-        # miss it here
-        event_map = self._get_events_from_cache(
-            member_event_ids, allow_rejected=False, update_metrics=False
-        )
-
-        missing_member_event_ids = []
-        for event_id in member_event_ids:
-            ev_entry = event_map.get(event_id)
-            if ev_entry:
-                if ev_entry.event.membership == Membership.JOIN:
-                    users_in_room[to_ascii(ev_entry.event.state_key)] = ProfileInfo(
-                        display_name=to_ascii(
-                            ev_entry.event.content.get("displayname", None)
-                        ),
-                        avatar_url=to_ascii(
-                            ev_entry.event.content.get("avatar_url", None)
-                        ),
-                    )
-            else:
-                missing_member_event_ids.append(event_id)
-
-        if missing_member_event_ids:
-            rows = yield self._simple_select_many_batch(
-                table="room_memberships",
-                column="event_id",
-                iterable=missing_member_event_ids,
-                retcols=("user_id", "display_name", "avatar_url"),
-                keyvalues={"membership": Membership.JOIN},
-                batch_size=500,
-                desc="_get_joined_users_from_context",
-            )
-
-            users_in_room.update(
-                {
-                    to_ascii(row["user_id"]): ProfileInfo(
-                        avatar_url=to_ascii(row["avatar_url"]),
-                        display_name=to_ascii(row["display_name"]),
-                    )
-                    for row in rows
-                }
-            )
-
-        if event is not None and event.type == EventTypes.Member:
-            if event.membership == Membership.JOIN:
-                if event.event_id in member_event_ids:
-                    users_in_room[to_ascii(event.state_key)] = ProfileInfo(
-                        display_name=to_ascii(event.content.get("displayname", None)),
-                        avatar_url=to_ascii(event.content.get("avatar_url", None)),
-                    )
-
-        return users_in_room
-
-    @cachedInlineCallbacks(max_entries=10000)
-    def is_host_joined(self, room_id, host):
-        if "%" in host or "_" in host:
-            raise Exception("Invalid host name")
-
-        sql = """
-            SELECT state_key FROM current_state_events AS c
-            INNER JOIN room_memberships AS m USING (event_id)
-            WHERE m.membership = 'join'
-                AND type = 'm.room.member'
-                AND c.room_id = ?
-                AND state_key LIKE ?
-            LIMIT 1
-        """
-
-        # We do need to be careful to ensure that host doesn't have any wild cards
-        # in it, but we checked above for known ones and we'll check below that
-        # the returned user actually has the correct domain.
-        like_clause = "%:" + host
-
-        rows = yield self._execute("is_host_joined", None, sql, room_id, like_clause)
-
-        if not rows:
-            return False
-
-        user_id = rows[0][0]
-        if get_domain_from_id(user_id) != host:
-            # This can only happen if the host name has something funky in it
-            raise Exception("Invalid host name")
-
-        return True
-
-    @cachedInlineCallbacks()
-    def was_host_joined(self, room_id, host):
-        """Check whether the server is or ever was in the room.
-
-        Args:
-            room_id (str)
-            host (str)
-
-        Returns:
-            Deferred: Resolves to True if the host is/was in the room, otherwise
-            False.
-        """
-        if "%" in host or "_" in host:
-            raise Exception("Invalid host name")
-
-        sql = """
-            SELECT user_id FROM room_memberships
-            WHERE room_id = ?
-                AND user_id LIKE ?
-                AND membership = 'join'
-            LIMIT 1
-        """
-
-        # We do need to be careful to ensure that host doesn't have any wild cards
-        # in it, but we checked above for known ones and we'll check below that
-        # the returned user actually has the correct domain.
-        like_clause = "%:" + host
-
-        rows = yield self._execute("was_host_joined", None, sql, room_id, like_clause)
-
-        if not rows:
-            return False
-
-        user_id = rows[0][0]
-        if get_domain_from_id(user_id) != host:
-            # This can only happen if the host name has something funky in it
-            raise Exception("Invalid host name")
-
-        return True
-
-    def get_joined_hosts(self, room_id, state_entry):
-        state_group = state_entry.state_group
-        if not state_group:
-            # If state_group is None it means it has yet to be assigned a
-            # state group, i.e. we need to make sure that calls with a state_group
-            # of None don't hit previous cached calls with a None state_group.
-            # To do this we set the state_group to a new object as object() != object()
-            state_group = object()
-
-        return self._get_joined_hosts(
-            room_id, state_group, state_entry.state, state_entry=state_entry
-        )
-
-    @cachedInlineCallbacks(num_args=2, max_entries=10000, iterable=True)
-    # @defer.inlineCallbacks
-    def _get_joined_hosts(self, room_id, state_group, current_state_ids, state_entry):
-        # We don't use `state_group`, it's there so that we can cache based
-        # on it. However, it's important that it's never None, since two current_states
-        # with a state_group of None are likely to be different.
-        # See bulk_get_push_rules_for_room for how we work around this.
-        assert state_group is not None
-
-        cache = self._get_joined_hosts_cache(room_id)
-        joined_hosts = yield cache.get_destinations(state_entry)
-
-        return joined_hosts
-
-    @cached(max_entries=10000)
-    def _get_joined_hosts_cache(self, room_id):
-        return _JoinedHostsCache(self, room_id)
-
-    @cachedInlineCallbacks(num_args=2)
-    def did_forget(self, user_id, room_id):
-        """Returns whether user_id has elected to discard history for room_id.
-
-        Returns False if they have since re-joined."""
-
-        def f(txn):
-            sql = (
-                "SELECT"
-                "  COUNT(*)"
-                " FROM"
-                "  room_memberships"
-                " WHERE"
-                "  user_id = ?"
-                " AND"
-                "  room_id = ?"
-                " AND"
-                "  forgotten = 0"
-            )
-            txn.execute(sql, (user_id, room_id))
-            rows = txn.fetchall()
-            return rows[0][0]
-
-        count = yield self.runInteraction("did_forget_membership", f)
-        return count == 0
-
-    @cached()
-    def get_forgotten_rooms_for_user(self, user_id):
-        """Gets all rooms the user has forgotten.
-
-        Args:
-            user_id (str)
-
-        Returns:
-            Deferred[set[str]]
-        """
-
-        def _get_forgotten_rooms_for_user_txn(txn):
-            # This is a slightly convoluted query that first looks up all rooms
-            # that the user has forgotten in the past, then rechecks that list
-            # to see if any have subsequently been updated. This is done so that
-            # we can use a partial index on `forgotten = 1` on the assumption
-            # that few users will actually forget many rooms.
-            #
-            # Note that a room is considered "forgotten" if *all* membership
-            # events for that user and room have the forgotten field set (as
-            # when a user forgets a room we update all rows for that user and
-            # room, not just the current one).
-            sql = """
-                SELECT room_id, (
-                    SELECT count(*) FROM room_memberships
-                    WHERE room_id = m.room_id AND user_id = m.user_id AND forgotten = 0
-                ) AS count
-                FROM room_memberships AS m
-                WHERE user_id = ? AND forgotten = 1
-                GROUP BY room_id, user_id;
-            """
-            txn.execute(sql, (user_id,))
-            return set(row[0] for row in txn if row[1] == 0)
-
-        return self.runInteraction(
-            "get_forgotten_rooms_for_user", _get_forgotten_rooms_for_user_txn
-        )
-
-    @defer.inlineCallbacks
-    def get_rooms_user_has_been_in(self, user_id):
-        """Get all rooms that the user has ever been in.
-
-        Args:
-            user_id (str)
-
-        Returns:
-            Deferred[set[str]]: Set of room IDs.
-        """
-
-        room_ids = yield self._simple_select_onecol(
-            table="room_memberships",
-            keyvalues={"membership": Membership.JOIN, "user_id": user_id},
-            retcol="room_id",
-            desc="get_rooms_user_has_been_in",
-        )
-
-        return set(room_ids)
-
-
-class RoomMemberStore(RoomMemberWorkerStore):
-    def __init__(self, db_conn, hs):
-        super(RoomMemberStore, self).__init__(db_conn, hs)
-        self.register_background_update_handler(
-            _MEMBERSHIP_PROFILE_UPDATE_NAME, self._background_add_membership_profile
-        )
-        self.register_background_update_handler(
-            _CURRENT_STATE_MEMBERSHIP_UPDATE_NAME,
-            self._background_current_state_membership,
-        )
-        self.register_background_index_update(
-            "room_membership_forgotten_idx",
-            index_name="room_memberships_user_room_forgotten",
-            table="room_memberships",
-            columns=["user_id", "room_id"],
-            where_clause="forgotten = 1",
-        )
-
-    def _store_room_members_txn(self, txn, events, backfilled):
-        """Store a room member in the database.
-        """
-        self._simple_insert_many_txn(
-            txn,
-            table="room_memberships",
-            values=[
-                {
-                    "event_id": event.event_id,
-                    "user_id": event.state_key,
-                    "sender": event.user_id,
-                    "room_id": event.room_id,
-                    "membership": event.membership,
-                    "display_name": event.content.get("displayname", None),
-                    "avatar_url": event.content.get("avatar_url", None),
-                }
-                for event in events
-            ],
-        )
-
-        for event in events:
-            txn.call_after(
-                self._membership_stream_cache.entity_has_changed,
-                event.state_key,
-                event.internal_metadata.stream_ordering,
-            )
-            txn.call_after(
-                self.get_invited_rooms_for_user.invalidate, (event.state_key,)
-            )
-
-            # We update the local_invites table only if the event is "current",
-            # i.e., it's something that has just happened. If the event is an
-            # outlier it is only current if it's an "out of band membership",
-            # like a remote invite or a rejection of a remote invite.
-            is_new_state = not backfilled and (
-                not event.internal_metadata.is_outlier()
-                or event.internal_metadata.is_out_of_band_membership()
-            )
-            is_mine = self.hs.is_mine_id(event.state_key)
-            if is_new_state and is_mine:
-                if event.membership == Membership.INVITE:
-                    self._simple_insert_txn(
-                        txn,
-                        table="local_invites",
-                        values={
-                            "event_id": event.event_id,
-                            "invitee": event.state_key,
-                            "inviter": event.sender,
-                            "room_id": event.room_id,
-                            "stream_id": event.internal_metadata.stream_ordering,
-                        },
-                    )
-                else:
-                    sql = (
-                        "UPDATE local_invites SET stream_id = ?, replaced_by = ? WHERE"
-                        " room_id = ? AND invitee = ? AND locally_rejected is NULL"
-                        " AND replaced_by is NULL"
-                    )
-
-                    txn.execute(
-                        sql,
-                        (
-                            event.internal_metadata.stream_ordering,
-                            event.event_id,
-                            event.room_id,
-                            event.state_key,
-                        ),
-                    )
-
-    @defer.inlineCallbacks
-    def locally_reject_invite(self, user_id, room_id):
-        sql = (
-            "UPDATE local_invites SET stream_id = ?, locally_rejected = ? WHERE"
-            " room_id = ? AND invitee = ? AND locally_rejected is NULL"
-            " AND replaced_by is NULL"
-        )
-
-        def f(txn, stream_ordering):
-            txn.execute(sql, (stream_ordering, True, room_id, user_id))
-
-        with self._stream_id_gen.get_next() as stream_ordering:
-            yield self.runInteraction("locally_reject_invite", f, stream_ordering)
-
-    def forget(self, user_id, room_id):
-        """Indicate that user_id wishes to discard history for room_id."""
-
-        def f(txn):
-            sql = (
-                "UPDATE"
-                "  room_memberships"
-                " SET"
-                "  forgotten = 1"
-                " WHERE"
-                "  user_id = ?"
-                " AND"
-                "  room_id = ?"
-            )
-            txn.execute(sql, (user_id, room_id))
-
-            self._invalidate_cache_and_stream(txn, self.did_forget, (user_id, room_id))
-            self._invalidate_cache_and_stream(
-                txn, self.get_forgotten_rooms_for_user, (user_id,)
-            )
-
-        return self.runInteraction("forget_membership", f)
-
-    @defer.inlineCallbacks
-    def _background_add_membership_profile(self, progress, batch_size):
-        target_min_stream_id = progress.get(
-            "target_min_stream_id_inclusive", self._min_stream_order_on_start
-        )
-        max_stream_id = progress.get(
-            "max_stream_id_exclusive", self._stream_order_on_start + 1
-        )
-
-        INSERT_CLUMP_SIZE = 1000
-
-        def add_membership_profile_txn(txn):
-            sql = """
-                SELECT stream_ordering, event_id, events.room_id, event_json.json
-                FROM events
-                INNER JOIN event_json USING (event_id)
-                INNER JOIN room_memberships USING (event_id)
-                WHERE ? <= stream_ordering AND stream_ordering < ?
-                AND type = 'm.room.member'
-                ORDER BY stream_ordering DESC
-                LIMIT ?
-            """
-
-            txn.execute(sql, (target_min_stream_id, max_stream_id, batch_size))
-
-            rows = self.cursor_to_dict(txn)
-            if not rows:
-                return 0
-
-            min_stream_id = rows[-1]["stream_ordering"]
-
-            to_update = []
-            for row in rows:
-                event_id = row["event_id"]
-                room_id = row["room_id"]
-                try:
-                    event_json = json.loads(row["json"])
-                    content = event_json["content"]
-                except Exception:
-                    continue
-
-                display_name = content.get("displayname", None)
-                avatar_url = content.get("avatar_url", None)
-
-                if display_name or avatar_url:
-                    to_update.append((display_name, avatar_url, event_id, room_id))
-
-            to_update_sql = """
-                UPDATE room_memberships SET display_name = ?, avatar_url = ?
-                WHERE event_id = ? AND room_id = ?
-            """
-            for index in range(0, len(to_update), INSERT_CLUMP_SIZE):
-                clump = to_update[index : index + INSERT_CLUMP_SIZE]
-                txn.executemany(to_update_sql, clump)
-
-            progress = {
-                "target_min_stream_id_inclusive": target_min_stream_id,
-                "max_stream_id_exclusive": min_stream_id,
-            }
-
-            self._background_update_progress_txn(
-                txn, _MEMBERSHIP_PROFILE_UPDATE_NAME, progress
-            )
-
-            return len(rows)
-
-        result = yield self.runInteraction(
-            _MEMBERSHIP_PROFILE_UPDATE_NAME, add_membership_profile_txn
-        )
-
-        if not result:
-            yield self._end_background_update(_MEMBERSHIP_PROFILE_UPDATE_NAME)
-
-        return result
-
-    @defer.inlineCallbacks
-    def _background_current_state_membership(self, progress, batch_size):
-        """Update the new membership column on current_state_events.
-
-        This works by iterating over all rooms in alphabetical order.
-        """
-
-        def _background_current_state_membership_txn(txn, last_processed_room):
-            processed = 0
-            while processed < batch_size:
-                txn.execute(
-                    """
-                        SELECT MIN(room_id) FROM current_state_events WHERE room_id > ?
-                    """,
-                    (last_processed_room,),
-                )
-                row = txn.fetchone()
-                if not row or not row[0]:
-                    return processed, True
-
-                next_room, = row
-
-                sql = """
-                    UPDATE current_state_events
-                    SET membership = (
-                        SELECT membership FROM room_memberships
-                        WHERE event_id = current_state_events.event_id
-                    )
-                    WHERE room_id = ?
-                """
-                txn.execute(sql, (next_room,))
-                processed += txn.rowcount
-
-                last_processed_room = next_room
-
-            self._background_update_progress_txn(
-                txn,
-                _CURRENT_STATE_MEMBERSHIP_UPDATE_NAME,
-                {"last_processed_room": last_processed_room},
-            )
-
-            return processed, False
-
-        # If we haven't got a last processed room then just use the empty
-        # string, which will compare before all room IDs correctly.
-        last_processed_room = progress.get("last_processed_room", "")
-
-        row_count, finished = yield self.runInteraction(
-            "_background_current_state_membership_update",
-            _background_current_state_membership_txn,
-            last_processed_room,
-        )
-
-        if finished:
-            yield self._end_background_update(_CURRENT_STATE_MEMBERSHIP_UPDATE_NAME)
-
-        return row_count
-
-
-class _JoinedHostsCache(object):
-    """Cache for joined hosts in a room that is optimised to handle updates
-    via state deltas.
-    """
-
-    def __init__(self, store, room_id):
-        self.store = store
-        self.room_id = room_id
-
-        self.hosts_to_joined_users = {}
-
-        self.state_group = object()
-
-        self.linearizer = Linearizer("_JoinedHostsCache")
-
-        self._len = 0
-
-    @defer.inlineCallbacks
-    def get_destinations(self, state_entry):
-        """Get set of destinations for a state entry
-
-        Args:
-            state_entry(synapse.state._StateCacheEntry)
-        """
-        if state_entry.state_group == self.state_group:
-            return frozenset(self.hosts_to_joined_users)
-
-        with (yield self.linearizer.queue(())):
-            if state_entry.state_group == self.state_group:
-                pass
-            elif state_entry.prev_group == self.state_group:
-                for (typ, state_key), event_id in iteritems(state_entry.delta_ids):
-                    if typ != EventTypes.Member:
-                        continue
-
-                    host = intern_string(get_domain_from_id(state_key))
-                    user_id = state_key
-                    known_joins = self.hosts_to_joined_users.setdefault(host, set())
-
-                    event = yield self.store.get_event(event_id)
-                    if event.membership == Membership.JOIN:
-                        known_joins.add(user_id)
-                    else:
-                        known_joins.discard(user_id)
-
-                        if not known_joins:
-                            self.hosts_to_joined_users.pop(host, None)
-            else:
-                joined_users = yield self.store.get_joined_users_from_state(
-                    self.room_id, state_entry
-                )
-
-                self.hosts_to_joined_users = {}
-                for user_id in joined_users:
-                    host = intern_string(get_domain_from_id(user_id))
-                    self.hosts_to_joined_users.setdefault(host, set()).add(user_id)
-
-            if state_entry.state_group:
-                self.state_group = state_entry.state_group
-            else:
-                self.state_group = object()
-            self._len = sum(len(v) for v in itervalues(self.hosts_to_joined_users))
-        return frozenset(self.hosts_to_joined_users)
-
-    def __len__(self):
-        return self._len
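Aside: the membership methods removed above repeatedly key their caches on (room_id, state_group) and substitute a fresh object() when the state group has not yet been assigned, so two still-unassigned states can never share a cache entry. A minimal sketch of that trick, with hypothetical names and outside any Synapse class, would be:

    # Hedged sketch only -- not Synapse code. object() != object(), so a
    # placeholder key for an unassigned state group is never reused, which
    # prevents accidental cache hits between different unassigned states.
    def _membership_cache_key(room_id, state_group):
        if not state_group:
            state_group = object()  # unique per call
        return (room_id, state_group)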
diff --git a/synapse/storage/schema/delta/35/00background_updates_add_col.sql b/synapse/storage/schema/delta/35/00background_updates_add_col.sql
new file mode 100644
index 0000000000..c2d2a4f836
--- /dev/null
+++ b/synapse/storage/schema/delta/35/00background_updates_add_col.sql
@@ -0,0 +1,17 @@
+/* Copyright 2016 OpenMarket Ltd
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+
+ALTER TABLE background_updates ADD COLUMN depends_on TEXT;
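The depends_on column added here lets one background update declare that it must wait for another. As a hedged illustration only (the helper below is hypothetical, not the code this migration ships with), a scheduler could use the column to filter out updates whose dependency is still queued:

    # Hedged example: given the rows of background_updates as dicts, keep only
    # the updates whose declared dependency has already finished (i.e. is no
    # longer present in the table).
    def runnable_updates(rows):
        pending = {row["update_name"] for row in rows}
        return [
            row
            for row in rows
            if not row["depends_on"] or row["depends_on"] not in pending
        ]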
diff --git a/synapse/storage/schema/delta/58/00background_update_ordering.sql b/synapse/storage/schema/delta/58/00background_update_ordering.sql
new file mode 100644
index 0000000000..02dae587cc
--- /dev/null
+++ b/synapse/storage/schema/delta/58/00background_update_ordering.sql
@@ -0,0 +1,19 @@
+/* Copyright 2020 The Matrix.org Foundation C.I.C
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+/* add an "ordering" column to background_updates, which can be used to sort them
+   to achieve some level of consistency. */
+
+ALTER TABLE background_updates ADD COLUMN ordering INT NOT NULL DEFAULT 0;
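As the comment in this migration says, the ordering column gives queued updates a deterministic sort order. A hedged example of consuming it (the function name and exact query are illustrative, not the shipped implementation):

    # Hedged example: pick the next queued update, using update_name as a
    # stable tie-break when two rows share the same ordering value.
    def next_background_update(txn):
        txn.execute(
            "SELECT update_name, depends_on FROM background_updates"
            " ORDER BY ordering, update_name LIMIT 1"
        )
        return txn.fetchone()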
diff --git a/synapse/storage/schema/full_schemas/54/full.sql b/synapse/storage/schema/full_schemas/54/full.sql
new file mode 100644
index 0000000000..1005880466
--- /dev/null
+++ b/synapse/storage/schema/full_schemas/54/full.sql
@@ -0,0 +1,8 @@
+
+
+CREATE TABLE background_updates (
+    update_name text NOT NULL,
+    progress_json text NOT NULL,
+    depends_on text,
+    CONSTRAINT background_updates_uniqueness UNIQUE (update_name)
+);
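For context, here is a hedged sketch of queuing a row against the table defined above (the helper name is hypothetical; only the column names come from the schema):

    # Hedged example: progress_json starts as an empty JSON object, and the
    # UNIQUE constraint on update_name makes double-queuing fail loudly.
    def queue_background_update(txn, update_name, depends_on=None):
        txn.execute(
            "INSERT INTO background_updates (update_name, progress_json, depends_on)"
            " VALUES (?, ?, ?)",
            (update_name, "{}", depends_on),
        )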
diff --git a/synapse/storage/schema/full_schemas/README.txt b/synapse/storage/schema/full_schemas/README.txt
deleted file mode 100644
index d3f6401344..0000000000
--- a/synapse/storage/schema/full_schemas/README.txt
+++ /dev/null
@@ -1,19 +0,0 @@
-Building full schema dumps
-==========================
-
-These schemas need to be made from a database that has had all background updates run.
-
-Postgres
---------
-
-$ pg_dump --format=plain --schema-only --no-tablespaces --no-acl --no-owner $DATABASE_NAME| sed -e '/^--/d' -e 's/public\.//g' -e '/^SET /d' -e '/^SELECT /d' > full.sql.postgres
-
-SQLite
-------
-
-$ sqlite3 $DATABASE_FILE ".schema" > full.sql.sqlite
-
-After
------
-
-Delete the CREATE statements for "sqlite_stat1", "schema_version", "applied_schema_deltas", and "applied_module_schemas".
\ No newline at end of file
diff --git a/synapse/storage/state.py b/synapse/storage/state.py
index 1980a87108..c522c80922 100644
--- a/synapse/storage/state.py
+++ b/synapse/storage/state.py
@@ -14,43 +14,21 @@
 # limitations under the License.
 
 import logging
-from collections import namedtuple
+from typing import Iterable, List, TypeVar
 
 from six import iteritems, itervalues
-from six.moves import range
 
 import attr
 
 from twisted.internet import defer
 
 from synapse.api.constants import EventTypes
-from synapse.api.errors import NotFoundError
-from synapse.storage._base import SQLBaseStore
-from synapse.storage.background_updates import BackgroundUpdateStore
-from synapse.storage.engines import PostgresEngine
-from synapse.storage.events_worker import EventsWorkerStore
-from synapse.util.caches import get_cache_factor_for, intern_string
-from synapse.util.caches.descriptors import cached, cachedList
-from synapse.util.caches.dictionary_cache import DictionaryCache
-from synapse.util.stringutils import to_ascii
+from synapse.types import StateMap
 
 logger = logging.getLogger(__name__)
 
-
-MAX_STATE_DELTA_HOPS = 100
-
-
-class _GetStateGroupDelta(
-    namedtuple("_GetStateGroupDelta", ("prev_group", "delta_ids"))
-):
-    """Return type of get_state_group_delta that implements __len__, which lets
-    us use the iterable flag when caching
-    """
-
-    __slots__ = []
-
-    def __len__(self):
-        return len(self.delta_ids) if self.delta_ids else 0
+# Used for generic functions below
+T = TypeVar("T")
 
 
 @attr.s(slots=True)
@@ -260,14 +238,14 @@ class StateFilter(object):
 
         return len(self.concrete_types())
 
-    def filter_state(self, state_dict):
+    def filter_state(self, state_dict: StateMap[T]) -> StateMap[T]:
         """Returns the state filtered with by this StateFilter
 
         Args:
-            state (dict[tuple[str, str], Any]): The state map to filter
+            state: The state map to filter
 
         Returns:
-            dict[tuple[str, str], Any]: The filtered state map
+            The filtered state map
         """
         if self.is_full():
             return dict(state_dict)
@@ -353,252 +331,23 @@ class StateFilter(object):
         return member_filter, non_member_filter
 
 
-# this inherits from EventsWorkerStore because it calls self.get_events
-class StateGroupWorkerStore(EventsWorkerStore, SQLBaseStore):
-    """The parts of StateGroupStore that can be called from workers.
+class StateGroupStorage(object):
+    """High level interface to fetching state for event.
     """
 
-    STATE_GROUP_DEDUPLICATION_UPDATE_NAME = "state_group_state_deduplication"
-    STATE_GROUP_INDEX_UPDATE_NAME = "state_group_state_type_index"
-    CURRENT_STATE_INDEX_UPDATE_NAME = "current_state_members_idx"
-
-    def __init__(self, db_conn, hs):
-        super(StateGroupWorkerStore, self).__init__(db_conn, hs)
-
-        # Originally the state store used a single DictionaryCache to cache the
-        # event IDs for the state types in a given state group to avoid hammering
-        # on the state_group* tables.
-        #
-        # The point of using a DictionaryCache is that it can cache a subset
-        # of the state events for a given state group (i.e. a subset of the keys for a
-        # given dict which is an entry in the cache for a given state group ID).
-        #
-        # However, this poses problems when performing complicated queries
-        # on the store - for instance: "give me all the state for this group, but
-        # limit members to this subset of users", as DictionaryCache's API isn't
-        # rich enough to say "please cache any of these fields, apart from this subset".
-        # This is problematic when lazy loading members, which requires this behaviour,
-        # as without it the cache has no choice but to speculatively load all
-        # state events for the group, which negates the efficiency being sought.
-        #
-        # Rather than overcomplicating DictionaryCache's API, we instead split the
-        # state_group_cache into two halves - one for tracking non-member events,
-        # and the other for tracking member_events.  This means that lazy loading
-        # queries can be made in a cache-friendly manner by querying both caches
-        # separately and then merging the result.  So for the example above, you
-        # would query the members cache for a specific subset of state keys
-        # (which DictionaryCache will handle efficiently and fine) and the non-members
-        # cache for all state (which DictionaryCache will similarly handle fine)
-        # and then just merge the results together.
-        #
-        # We size the non-members cache to be smaller than the members cache as the
-        # vast majority of state in Matrix (today) is member events.
-
-        self._state_group_cache = DictionaryCache(
-            "*stateGroupCache*",
-            # TODO: this hasn't been tuned yet
-            50000 * get_cache_factor_for("stateGroupCache"),
-        )
-        self._state_group_members_cache = DictionaryCache(
-            "*stateGroupMembersCache*",
-            500000 * get_cache_factor_for("stateGroupMembersCache"),
-        )
-
-    @defer.inlineCallbacks
-    def get_room_version(self, room_id):
-        """Get the room_version of a given room
-
-        Args:
-            room_id (str)
-
-        Returns:
-            Deferred[str]
-
-        Raises:
-            NotFoundError if the room is unknown
-        """
-        # for now we do this by looking at the create event. We may want to cache this
-        # more intelligently in future.
-
-        # Retrieve the room's create event
-        create_event = yield self.get_create_event_for_room(room_id)
-        return create_event.content.get("room_version", "1")
-
-    @defer.inlineCallbacks
-    def get_room_predecessor(self, room_id):
-        """Get the predecessor room of an upgraded room if one exists.
-        Otherwise return None.
-
-        Args:
-            room_id (str)
-
-        Returns:
-            Deferred[unicode|None]: predecessor room id
-
-        Raises:
-            NotFoundError if the room is unknown
-        """
-        # Retrieve the room's create event
-        create_event = yield self.get_create_event_for_room(room_id)
-
-        # Return predecessor if present
-        return create_event.content.get("predecessor", None)
-
-    @defer.inlineCallbacks
-    def get_create_event_for_room(self, room_id):
-        """Get the create state event for a room.
-
-        Args:
-            room_id (str)
-
-        Returns:
-            Deferred[EventBase]: The room creation event.
-
-        Raises:
-            NotFoundError if the room is unknown
-        """
-        state_ids = yield self.get_current_state_ids(room_id)
-        create_id = state_ids.get((EventTypes.Create, ""))
-
-        # If we can't find the create event, assume we've hit a dead end
-        if not create_id:
-            raise NotFoundError("Unknown room %s" % (room_id))
-
-        # Retrieve the room's create event and return
-        create_event = yield self.get_event(create_id)
-        return create_event
-
-    @cached(max_entries=100000, iterable=True)
-    def get_current_state_ids(self, room_id):
-        """Get the current state event ids for a room based on the
-        current_state_events table.
-
-        Args:
-            room_id (str)
-
-        Returns:
-            deferred: dict of (type, state_key) -> event_id
-        """
-
-        def _get_current_state_ids_txn(txn):
-            txn.execute(
-                """SELECT type, state_key, event_id FROM current_state_events
-                WHERE room_id = ?
-                """,
-                (room_id,),
-            )
-
-            return {
-                (intern_string(r[0]), intern_string(r[1])): to_ascii(r[2]) for r in txn
-            }
-
-        return self.runInteraction("get_current_state_ids", _get_current_state_ids_txn)
-
-    # FIXME: how should this be cached?
-    def get_filtered_current_state_ids(self, room_id, state_filter=StateFilter.all()):
-        """Get the current state event of a given type for a room based on the
-        current_state_events table.  This may not be as up-to-date as the result
-        of doing a fresh state resolution as per state_handler.get_current_state
-
-        Args:
-            room_id (str)
-            state_filter (StateFilter): The state filter used to fetch state
-                from the database.
-
-        Returns:
-            Deferred[dict[tuple[str, str], str]]: Map from type/state_key to
-            event ID.
-        """
-
-        where_clause, where_args = state_filter.make_sql_filter_clause()
-
-        if not where_clause:
-            # We delegate to the cached version
-            return self.get_current_state_ids(room_id)
-
-        def _get_filtered_current_state_ids_txn(txn):
-            results = {}
-            sql = """
-                SELECT type, state_key, event_id FROM current_state_events
-                WHERE room_id = ?
-            """
-
-            if where_clause:
-                sql += " AND (%s)" % (where_clause,)
+    def __init__(self, hs, stores):
+        self.stores = stores
 
-            args = [room_id]
-            args.extend(where_args)
-            txn.execute(sql, args)
-            for row in txn:
-                typ, state_key, event_id = row
-                key = (intern_string(typ), intern_string(state_key))
-                results[key] = event_id
-
-            return results
-
-        return self.runInteraction(
-            "get_filtered_current_state_ids", _get_filtered_current_state_ids_txn
-        )
-
-    @defer.inlineCallbacks
-    def get_canonical_alias_for_room(self, room_id):
-        """Get canonical alias for room, if any
-
-        Args:
-            room_id (str)
-
-        Returns:
-            Deferred[str|None]: The canonical alias, if any
-        """
-
-        state = yield self.get_filtered_current_state_ids(
-            room_id, StateFilter.from_types([(EventTypes.CanonicalAlias, "")])
-        )
-
-        event_id = state.get((EventTypes.CanonicalAlias, ""))
-        if not event_id:
-            return
-
-        event = yield self.get_event(event_id, allow_none=True)
-        if not event:
-            return
-
-        return event.content.get("canonical_alias")
-
-    @cached(max_entries=10000, iterable=True)
-    def get_state_group_delta(self, state_group):
+    def get_state_group_delta(self, state_group: int):
         """Given a state group try to return a previous group and a delta between
         the old and the new.
 
         Returns:
-            (prev_group, delta_ids), where both may be None.
+            Deferred[Tuple[Optional[int], Optional[StateMap[str]]]]:
+                (prev_group, delta_ids)
         """
 
-        def _get_state_group_delta_txn(txn):
-            prev_group = self._simple_select_one_onecol_txn(
-                txn,
-                table="state_group_edges",
-                keyvalues={"state_group": state_group},
-                retcol="prev_state_group",
-                allow_none=True,
-            )
-
-            if not prev_group:
-                return _GetStateGroupDelta(None, None)
-
-            delta_ids = self._simple_select_list_txn(
-                txn,
-                table="state_groups_state",
-                keyvalues={"state_group": state_group},
-                retcols=("type", "state_key", "event_id"),
-            )
-
-            return _GetStateGroupDelta(
-                prev_group,
-                {(row["type"], row["state_key"]): row["event_id"] for row in delta_ids},
-            )
-
-        return self.runInteraction("get_state_group_delta", _get_state_group_delta_txn)
+        return self.stores.state.get_state_group_delta(state_group)
 
     @defer.inlineCallbacks
     def get_state_groups_ids(self, _room_id, event_ids):
@@ -609,16 +358,16 @@ class StateGroupWorkerStore(EventsWorkerStore, SQLBaseStore):
             event_ids (iterable[str]): ids of the events
 
         Returns:
-            Deferred[dict[int, dict[tuple[str, str], str]]]:
+            Deferred[dict[int, StateMap[str]]]:
                 dict of state_group_id -> (dict of (type, state_key) -> event id)
         """
         if not event_ids:
             return {}
 
-        event_to_groups = yield self._get_state_group_for_events(event_ids)
+        event_to_groups = yield self.stores.main._get_state_group_for_events(event_ids)
 
         groups = set(itervalues(event_to_groups))
-        group_to_state = yield self._get_state_for_groups(groups)
+        group_to_state = yield self.stores.state._get_state_for_groups(groups)
 
         return group_to_state
 
@@ -639,7 +388,6 @@ class StateGroupWorkerStore(EventsWorkerStore, SQLBaseStore):
     @defer.inlineCallbacks
     def get_state_groups(self, room_id, event_ids):
         """ Get the state groups for the given list of event_ids
-
         Returns:
             Deferred[dict[int, list[EventBase]]]:
                 dict of state_group_id -> list of state events.
@@ -649,7 +397,7 @@ class StateGroupWorkerStore(EventsWorkerStore, SQLBaseStore):
 
         group_to_ids = yield self.get_state_groups_ids(room_id, event_ids)
 
-        state_event_map = yield self.get_events(
+        state_event_map = yield self.stores.main.get_events(
             [
                 ev_id
                 for group_ids in itervalues(group_to_ids)
@@ -667,153 +415,41 @@ class StateGroupWorkerStore(EventsWorkerStore, SQLBaseStore):
             for group, event_id_map in iteritems(group_to_ids)
         }
 
-    @defer.inlineCallbacks
-    def _get_state_groups_from_groups(self, groups, state_filter):
+    def _get_state_groups_from_groups(
+        self, groups: List[int], state_filter: StateFilter
+    ):
         """Returns the state groups for a given set of groups, filtering on
         types of state events.
 
         Args:
-            groups(list[int]): list of state group IDs to query
-            state_filter (StateFilter): The state filter used to fetch state
+            groups: list of state group IDs to query
+            state_filter: The state filter used to fetch state
                 from the database.
         Returns:
-            Deferred[dict[int, dict[tuple[str, str], str]]]:
-                dict of state_group_id -> (dict of (type, state_key) -> event id)
+            Deferred[Dict[int, StateMap[str]]]: Dict of state group to state map.
         """
-        results = {}
-
-        chunks = [groups[i : i + 100] for i in range(0, len(groups), 100)]
-        for chunk in chunks:
-            res = yield self.runInteraction(
-                "_get_state_groups_from_groups",
-                self._get_state_groups_from_groups_txn,
-                chunk,
-                state_filter,
-            )
-            results.update(res)
 
-        return results
-
-    def _get_state_groups_from_groups_txn(
-        self, txn, groups, state_filter=StateFilter.all()
-    ):
-        results = {group: {} for group in groups}
-
-        where_clause, where_args = state_filter.make_sql_filter_clause()
-
-        # Unless the filter clause is empty, we're going to append it after an
-        # existing where clause
-        if where_clause:
-            where_clause = " AND (%s)" % (where_clause,)
-
-        if isinstance(self.database_engine, PostgresEngine):
-            # Temporarily disable sequential scans in this transaction. This is
-            # a temporary hack until we can add the right indices.
-            txn.execute("SET LOCAL enable_seqscan=off")
-
-            # The below query walks the state_group tree so that the "state"
-            # table includes all state_groups in the tree. It then joins
-            # against `state_groups_state` to fetch the latest state.
-            # It assumes that previous state groups are always numerically
-            # lesser.
-            # The PARTITION is used to get the event_id in the greatest state
-            # group for the given type, state_key.
-            # This may return multiple rows per (type, state_key), but last_value
-            # should be the same.
-            sql = """
-                WITH RECURSIVE state(state_group) AS (
-                    VALUES(?::bigint)
-                    UNION ALL
-                    SELECT prev_state_group FROM state_group_edges e, state s
-                    WHERE s.state_group = e.state_group
-                )
-                SELECT DISTINCT type, state_key, last_value(event_id) OVER (
-                    PARTITION BY type, state_key ORDER BY state_group ASC
-                    ROWS BETWEEN UNBOUNDED PRECEDING AND UNBOUNDED FOLLOWING
-                ) AS event_id FROM state_groups_state
-                WHERE state_group IN (
-                    SELECT state_group FROM state
-                )
-            """
-
-            for group in groups:
-                args = [group]
-                args.extend(where_args)
-
-                txn.execute(sql + where_clause, args)
-                for row in txn:
-                    typ, state_key, event_id = row
-                    key = (typ, state_key)
-                    results[group][key] = event_id
-        else:
-            max_entries_returned = state_filter.max_entries_returned()
-
-            # We don't use WITH RECURSIVE on sqlite3 as there are distributions
-            # that ship with an sqlite3 version that doesn't support it (e.g. wheezy)
-            for group in groups:
-                next_group = group
-
-                while next_group:
-                    # We did this before by getting the list of group ids, and
-                    # then passing that list to sqlite to get latest event for
-                    # each (type, state_key). However, that was terribly slow
-                    # without the right indices (which we can't add until
-                    # after we finish deduping state, which requires this func)
-                    args = [next_group]
-                    args.extend(where_args)
-
-                    txn.execute(
-                        "SELECT type, state_key, event_id FROM state_groups_state"
-                        " WHERE state_group = ? " + where_clause,
-                        args,
-                    )
-                    results[group].update(
-                        ((typ, state_key), event_id)
-                        for typ, state_key, event_id in txn
-                        if (typ, state_key) not in results[group]
-                    )
-
-                    # If the number of entries in the (type,state_key)->event_id dict
-                    # matches the number of (type,state_keys) types we were searching
-                    # for, then we must have found them all, so no need to go walk
-                    # further down the tree... UNLESS our types filter contained
-                    # wildcards (i.e. Nones) in which case we have to do an exhaustive
-                    # search
-                    if (
-                        max_entries_returned is not None
-                        and len(results[group]) == max_entries_returned
-                    ):
-                        break
-
-                    next_group = self._simple_select_one_onecol_txn(
-                        txn,
-                        table="state_group_edges",
-                        keyvalues={"state_group": next_group},
-                        retcol="prev_state_group",
-                        allow_none=True,
-                    )
-
-        return results
+        return self.stores.state._get_state_groups_from_groups(groups, state_filter)
 
     @defer.inlineCallbacks
     def get_state_for_events(self, event_ids, state_filter=StateFilter.all()):
         """Given a list of event_ids and type tuples, return a list of state
         dicts for each event.
-
         Args:
             event_ids (list[string])
             state_filter (StateFilter): The state filter used to fetch state
                 from the database.
-
         Returns:
             deferred: A dict of (event_id) -> (type, state_key) -> [state_events]
         """
-        event_to_groups = yield self._get_state_group_for_events(event_ids)
+        event_to_groups = yield self.stores.main._get_state_group_for_events(event_ids)
 
         groups = set(itervalues(event_to_groups))
-        group_to_state = yield self._get_state_for_groups(groups, state_filter)
+        group_to_state = yield self.stores.state._get_state_for_groups(
+            groups, state_filter
+        )
 
-        state_event_map = yield self.get_events(
+        state_event_map = yield self.stores.main.get_events(
             [ev_id for sd in itervalues(group_to_state) for ev_id in itervalues(sd)],
             get_prev_content=False,
         )
@@ -843,10 +479,12 @@ class StateGroupWorkerStore(EventsWorkerStore, SQLBaseStore):
         Returns:
             A deferred dict from event_id -> (type, state_key) -> event_id
         """
-        event_to_groups = yield self._get_state_group_for_events(event_ids)
+        event_to_groups = yield self.stores.main._get_state_group_for_events(event_ids)
 
         groups = set(itervalues(event_to_groups))
-        group_to_state = yield self._get_state_for_groups(groups, state_filter)
+        group_to_state = yield self.stores.state._get_state_for_groups(
+            groups, state_filter
+        )
 
         event_to_state = {
             event_id: group_to_state[group]
@@ -887,77 +525,9 @@ class StateGroupWorkerStore(EventsWorkerStore, SQLBaseStore):
         state_map = yield self.get_state_ids_for_events([event_id], state_filter)
         return state_map[event_id]
 
-    @cached(max_entries=50000)
-    def _get_state_group_for_event(self, event_id):
-        return self._simple_select_one_onecol(
-            table="event_to_state_groups",
-            keyvalues={"event_id": event_id},
-            retcol="state_group",
-            allow_none=True,
-            desc="_get_state_group_for_event",
-        )
-
-    @cachedList(
-        cached_method_name="_get_state_group_for_event",
-        list_name="event_ids",
-        num_args=1,
-        inlineCallbacks=True,
-    )
-    def _get_state_group_for_events(self, event_ids):
-        """Returns mapping event_id -> state_group
-        """
-        rows = yield self._simple_select_many_batch(
-            table="event_to_state_groups",
-            column="event_id",
-            iterable=event_ids,
-            keyvalues={},
-            retcols=("event_id", "state_group"),
-            desc="_get_state_group_for_events",
-        )
-
-        return {row["event_id"]: row["state_group"] for row in rows}
-
-    def _get_state_for_group_using_cache(self, cache, group, state_filter):
-        """Checks if group is in cache. See `_get_state_for_groups`
-
-        Args:
-            cache(DictionaryCache): the state group cache to use
-            group(int): The state group to lookup
-            state_filter (StateFilter): The state filter used to fetch state
-                from the database.
-
-        Returns 2-tuple (`state_dict`, `got_all`).
-        `got_all` is a bool indicating if we successfully retrieved all
-        requested state from the cache; if False we need to query the DB for the
-        missing state.
-        """
-        is_all, known_absent, state_dict_ids = cache.get(group)
-
-        if is_all or state_filter.is_full():
-            # Either we have everything or want everything, either way
-            # `is_all` tells us whether we've gotten everything.
-            return state_filter.filter_state(state_dict_ids), is_all
-
-        # tracks whether any of our requested types are missing from the cache
-        missing_types = False
-
-        if state_filter.has_wildcards():
-            # We don't know if we fetched all the state keys for the types in
-            # the filter that are wildcards, so we have to assume that we may
-            # have missed some.
-            missing_types = True
-        else:
-            # There aren't any wild cards, so `concrete_types()` returns the
-            # complete list of event types we're wanting.
-            for key in state_filter.concrete_types():
-                if key not in state_dict_ids and key not in known_absent:
-                    missing_types = True
-                    break
-
-        return state_filter.filter_state(state_dict_ids), not missing_types
-
-    @defer.inlineCallbacks
-    def _get_state_for_groups(self, groups, state_filter=StateFilter.all()):
+    def _get_state_for_groups(
+        self, groups: Iterable[int], state_filter: StateFilter = StateFilter.all()
+    ):
         """Gets the state at each of a list of state groups, optionally
         filtering by type/state_key
 
@@ -967,157 +537,9 @@ class StateGroupWorkerStore(EventsWorkerStore, SQLBaseStore):
             state_filter (StateFilter): The state filter used to fetch state
                 from the database.
         Returns:
-            Deferred[dict[int, dict[tuple[str, str], str]]]:
-                dict of state_group_id -> (dict of (type, state_key) -> event id)
+            Deferred[dict[int, StateMap[str]]]: Dict of state group to state map.
         """
-
-        member_filter, non_member_filter = state_filter.get_member_split()
-
-        # Now we look them up in the member and non-member caches
-        non_member_state, incomplete_groups_nm, = (
-            yield self._get_state_for_groups_using_cache(
-                groups, self._state_group_cache, state_filter=non_member_filter
-            )
-        )
-
-        member_state, incomplete_groups_m, = (
-            yield self._get_state_for_groups_using_cache(
-                groups, self._state_group_members_cache, state_filter=member_filter
-            )
-        )
-
-        state = dict(non_member_state)
-        for group in groups:
-            state[group].update(member_state[group])
-
-        # Now fetch any missing groups from the database
-
-        incomplete_groups = incomplete_groups_m | incomplete_groups_nm
-
-        if not incomplete_groups:
-            return state
-
-        cache_sequence_nm = self._state_group_cache.sequence
-        cache_sequence_m = self._state_group_members_cache.sequence
-
-        # Help the cache hit ratio by expanding the filter a bit
-        db_state_filter = state_filter.return_expanded()
-
-        group_to_state_dict = yield self._get_state_groups_from_groups(
-            list(incomplete_groups), state_filter=db_state_filter
-        )
-
-        # Now lets update the caches
-        self._insert_into_cache(
-            group_to_state_dict,
-            db_state_filter,
-            cache_seq_num_members=cache_sequence_m,
-            cache_seq_num_non_members=cache_sequence_nm,
-        )
-
-        # And finally update the result dict, by filtering out any extra
-        # stuff we pulled out of the database.
-        for group, group_state_dict in iteritems(group_to_state_dict):
-            # We just replace any existing entries, as we will have loaded
-            # everything we need from the database anyway.
-            state[group] = state_filter.filter_state(group_state_dict)
-
-        return state
-
-    def _get_state_for_groups_using_cache(self, groups, cache, state_filter):
-        """Gets the state at each of a list of state groups, optionally
-        filtering by type/state_key, querying from a specific cache.
-
-        Args:
-            groups (iterable[int]): list of state groups for which we want
-                to get the state.
-            cache (DictionaryCache): the cache of group ids to state dicts which
-                we will pass through - either the normal state cache or the specific
-                members state cache.
-            state_filter (StateFilter): The state filter used to fetch state
-                from the database.
-
-        Returns:
-            tuple[dict[int, dict[tuple[str, str], str]], set[int]]: Tuple of
-            dict of state_group_id -> (dict of (type, state_key) -> event id)
-            of entries in the cache, and the state group ids either missing
-            from the cache or incomplete.
-        """
-        results = {}
-        incomplete_groups = set()
-        for group in set(groups):
-            state_dict_ids, got_all = self._get_state_for_group_using_cache(
-                cache, group, state_filter
-            )
-            results[group] = state_dict_ids
-
-            if not got_all:
-                incomplete_groups.add(group)
-
-        return results, incomplete_groups
-
-    def _insert_into_cache(
-        self,
-        group_to_state_dict,
-        state_filter,
-        cache_seq_num_members,
-        cache_seq_num_non_members,
-    ):
-        """Inserts results from querying the database into the relevant cache.
-
-        Args:
-            group_to_state_dict (dict): The new entries pulled from database.
-                Map from state group to state dict
-            state_filter (StateFilter): The state filter used to fetch state
-                from the database.
-            cache_seq_num_members (int): Sequence number of member cache since
-                last lookup in cache
-            cache_seq_num_non_members (int): Sequence number of non-member cache since
-                last lookup in cache
-        """
-
-        # We need to work out which types we've fetched from the DB for the
-        # member vs non-member caches. This should be as accurate as possible,
-        # but can be an underestimate (e.g. when we have wild cards)
-
-        member_filter, non_member_filter = state_filter.get_member_split()
-        if member_filter.is_full():
-            # We fetched all member events
-            member_types = None
-        else:
-            # `concrete_types()` will only return a subset when there are wild
-            # cards in the filter, but that's fine.
-            member_types = member_filter.concrete_types()
-
-        if non_member_filter.is_full():
-            # We fetched all non member events
-            non_member_types = None
-        else:
-            non_member_types = non_member_filter.concrete_types()
-
-        for group, group_state_dict in iteritems(group_to_state_dict):
-            state_dict_members = {}
-            state_dict_non_members = {}
-
-            for k, v in iteritems(group_state_dict):
-                if k[0] == EventTypes.Member:
-                    state_dict_members[k] = v
-                else:
-                    state_dict_non_members[k] = v
-
-            self._state_group_members_cache.update(
-                cache_seq_num_members,
-                key=group,
-                value=state_dict_members,
-                fetched_keys=member_types,
-            )
-
-            self._state_group_cache.update(
-                cache_seq_num_non_members,
-                key=group,
-                value=state_dict_non_members,
-                fetched_keys=non_member_types,
-            )
+        return self.stores.state._get_state_for_groups(groups, state_filter)
 
     def store_state_group(
         self, event_id, room_id, prev_group, delta_ids, current_state_ids
@@ -1137,393 +559,6 @@ class StateGroupWorkerStore(EventsWorkerStore, SQLBaseStore):
         Returns:
             Deferred[int]: The state group ID
         """
-
-        def _store_state_group_txn(txn):
-            if current_state_ids is None:
-                # AFAIK, this can never happen
-                raise Exception("current_state_ids cannot be None")
-
-            state_group = self.database_engine.get_next_state_group_id(txn)
-
-            self._simple_insert_txn(
-                txn,
-                table="state_groups",
-                values={"id": state_group, "room_id": room_id, "event_id": event_id},
-            )
-
-            # We persist as a delta if we can, while also ensuring the chain
-            # of deltas isn't too long, as otherwise read performance degrades.
-            if prev_group:
-                is_in_db = self._simple_select_one_onecol_txn(
-                    txn,
-                    table="state_groups",
-                    keyvalues={"id": prev_group},
-                    retcol="id",
-                    allow_none=True,
-                )
-                if not is_in_db:
-                    raise Exception(
-                        "Trying to persist state with unpersisted prev_group: %r"
-                        % (prev_group,)
-                    )
-
-                potential_hops = self._count_state_group_hops_txn(txn, prev_group)
-            if prev_group and potential_hops < MAX_STATE_DELTA_HOPS:
-                self._simple_insert_txn(
-                    txn,
-                    table="state_group_edges",
-                    values={"state_group": state_group, "prev_state_group": prev_group},
-                )
-
-                self._simple_insert_many_txn(
-                    txn,
-                    table="state_groups_state",
-                    values=[
-                        {
-                            "state_group": state_group,
-                            "room_id": room_id,
-                            "type": key[0],
-                            "state_key": key[1],
-                            "event_id": state_id,
-                        }
-                        for key, state_id in iteritems(delta_ids)
-                    ],
-                )
-            else:
-                self._simple_insert_many_txn(
-                    txn,
-                    table="state_groups_state",
-                    values=[
-                        {
-                            "state_group": state_group,
-                            "room_id": room_id,
-                            "type": key[0],
-                            "state_key": key[1],
-                            "event_id": state_id,
-                        }
-                        for key, state_id in iteritems(current_state_ids)
-                    ],
-                )
-
-            # Prefill the state group caches with this group.
-            # It's fine to use the sequence like this as the state group map
-            # is immutable. (If the map wasn't immutable then this prefill could
-            # race with another update)
-
-            current_member_state_ids = {
-                s: ev
-                for (s, ev) in iteritems(current_state_ids)
-                if s[0] == EventTypes.Member
-            }
-            txn.call_after(
-                self._state_group_members_cache.update,
-                self._state_group_members_cache.sequence,
-                key=state_group,
-                value=dict(current_member_state_ids),
-            )
-
-            current_non_member_state_ids = {
-                s: ev
-                for (s, ev) in iteritems(current_state_ids)
-                if s[0] != EventTypes.Member
-            }
-            txn.call_after(
-                self._state_group_cache.update,
-                self._state_group_cache.sequence,
-                key=state_group,
-                value=dict(current_non_member_state_ids),
-            )
-
-            return state_group
-
-        return self.runInteraction("store_state_group", _store_state_group_txn)
-
-    def _count_state_group_hops_txn(self, txn, state_group):
-        """Given a state group, count how many hops there are in the tree.
-
-        This is used to ensure the delta chains don't get too long.
-        """
-        if isinstance(self.database_engine, PostgresEngine):
-            sql = """
-                WITH RECURSIVE state(state_group) AS (
-                    VALUES(?::bigint)
-                    UNION ALL
-                    SELECT prev_state_group FROM state_group_edges e, state s
-                    WHERE s.state_group = e.state_group
-                )
-                SELECT count(*) FROM state;
-            """
-
-            txn.execute(sql, (state_group,))
-            row = txn.fetchone()
-            if row and row[0]:
-                return row[0]
-            else:
-                return 0
-        else:
-            # We don't use WITH RECURSIVE on sqlite3 as there are distributions
-            # that ship with an sqlite3 version that doesn't support it (e.g. wheezy)
-            next_group = state_group
-            count = 0
-
-            while next_group:
-                next_group = self._simple_select_one_onecol_txn(
-                    txn,
-                    table="state_group_edges",
-                    keyvalues={"state_group": next_group},
-                    retcol="prev_state_group",
-                    allow_none=True,
-                )
-                if next_group:
-                    count += 1
-
-            return count
-
-
-class StateStore(StateGroupWorkerStore, BackgroundUpdateStore):
-    """ Keeps track of the state at a given event.
-
-    This is done by the concept of `state groups`. Every event is a assigned
-    a state group (identified by an arbitrary string), which references a
-    collection of state events. The current state of an event is then the
-    collection of state events referenced by the event's state group.
-
-    Hence, every change in the current state causes a new state group to be
-    generated. However, if no change happens (e.g., if we get a message event
-    with only one parent it inherits the state group from its parent.)
-
-    There are three tables:
-      * `state_groups`: Stores group name, first event with in the group and
-        room id.
-      * `event_to_state_groups`: Maps events to state groups.
-      * `state_groups_state`: Maps state group to state events.
-    """
-
-    STATE_GROUP_DEDUPLICATION_UPDATE_NAME = "state_group_state_deduplication"
-    STATE_GROUP_INDEX_UPDATE_NAME = "state_group_state_type_index"
-    CURRENT_STATE_INDEX_UPDATE_NAME = "current_state_members_idx"
-    EVENT_STATE_GROUP_INDEX_UPDATE_NAME = "event_to_state_groups_sg_index"
-
-    def __init__(self, db_conn, hs):
-        super(StateStore, self).__init__(db_conn, hs)
-        self.register_background_update_handler(
-            self.STATE_GROUP_DEDUPLICATION_UPDATE_NAME,
-            self._background_deduplicate_state,
-        )
-        self.register_background_update_handler(
-            self.STATE_GROUP_INDEX_UPDATE_NAME, self._background_index_state
-        )
-        self.register_background_index_update(
-            self.CURRENT_STATE_INDEX_UPDATE_NAME,
-            index_name="current_state_events_member_index",
-            table="current_state_events",
-            columns=["state_key"],
-            where_clause="type='m.room.member'",
-        )
-        self.register_background_index_update(
-            self.EVENT_STATE_GROUP_INDEX_UPDATE_NAME,
-            index_name="event_to_state_groups_sg_index",
-            table="event_to_state_groups",
-            columns=["state_group"],
-        )
-
-    def _store_event_state_mappings_txn(self, txn, events_and_contexts):
-        state_groups = {}
-        for event, context in events_and_contexts:
-            if event.internal_metadata.is_outlier():
-                continue
-
-            # if the event was rejected, just give it the same state as its
-            # predecessor.
-            if context.rejected:
-                state_groups[event.event_id] = context.prev_group
-                continue
-
-            state_groups[event.event_id] = context.state_group
-
-        self._simple_insert_many_txn(
-            txn,
-            table="event_to_state_groups",
-            values=[
-                {"state_group": state_group_id, "event_id": event_id}
-                for event_id, state_group_id in iteritems(state_groups)
-            ],
-        )
-
-        for event_id, state_group_id in iteritems(state_groups):
-            txn.call_after(
-                self._get_state_group_for_event.prefill, (event_id,), state_group_id
-            )
-
-    @defer.inlineCallbacks
-    def _background_deduplicate_state(self, progress, batch_size):
-        """This background update will slowly deduplicate state by reencoding
-        them as deltas.
-        """
-        last_state_group = progress.get("last_state_group", 0)
-        rows_inserted = progress.get("rows_inserted", 0)
-        max_group = progress.get("max_group", None)
-
-        BATCH_SIZE_SCALE_FACTOR = 100
-
-        batch_size = max(1, int(batch_size / BATCH_SIZE_SCALE_FACTOR))
-
-        if max_group is None:
-            rows = yield self._execute(
-                "_background_deduplicate_state",
-                None,
-                "SELECT coalesce(max(id), 0) FROM state_groups",
-            )
-            max_group = rows[0][0]
-
-        def reindex_txn(txn):
-            new_last_state_group = last_state_group
-            for count in range(batch_size):
-                txn.execute(
-                    "SELECT id, room_id FROM state_groups"
-                    " WHERE ? < id AND id <= ?"
-                    " ORDER BY id ASC"
-                    " LIMIT 1",
-                    (new_last_state_group, max_group),
-                )
-                row = txn.fetchone()
-                if row:
-                    state_group, room_id = row
-
-                if not row or not state_group:
-                    return True, count
-
-                txn.execute(
-                    "SELECT state_group FROM state_group_edges"
-                    " WHERE state_group = ?",
-                    (state_group,),
-                )
-
-                # If we reach a point where we've already started inserting
-                # edges we should stop.
-                if txn.fetchall():
-                    return True, count
-
-                txn.execute(
-                    "SELECT coalesce(max(id), 0) FROM state_groups"
-                    " WHERE id < ? AND room_id = ?",
-                    (state_group, room_id),
-                )
-                prev_group, = txn.fetchone()
-                new_last_state_group = state_group
-
-                if prev_group:
-                    potential_hops = self._count_state_group_hops_txn(txn, prev_group)
-                    if potential_hops >= MAX_STATE_DELTA_HOPS:
-                        # We want to ensure chains are at most this long,#
-                        # otherwise read performance degrades.
-                        continue
-
-                    prev_state = self._get_state_groups_from_groups_txn(
-                        txn, [prev_group]
-                    )
-                    prev_state = prev_state[prev_group]
-
-                    curr_state = self._get_state_groups_from_groups_txn(
-                        txn, [state_group]
-                    )
-                    curr_state = curr_state[state_group]
-
-                    if not set(prev_state.keys()) - set(curr_state.keys()):
-                        # We can only do a delta if the current has a strict super set
-                        # of keys
-
-                        delta_state = {
-                            key: value
-                            for key, value in iteritems(curr_state)
-                            if prev_state.get(key, None) != value
-                        }
-
-                        self._simple_delete_txn(
-                            txn,
-                            table="state_group_edges",
-                            keyvalues={"state_group": state_group},
-                        )
-
-                        self._simple_insert_txn(
-                            txn,
-                            table="state_group_edges",
-                            values={
-                                "state_group": state_group,
-                                "prev_state_group": prev_group,
-                            },
-                        )
-
-                        self._simple_delete_txn(
-                            txn,
-                            table="state_groups_state",
-                            keyvalues={"state_group": state_group},
-                        )
-
-                        self._simple_insert_many_txn(
-                            txn,
-                            table="state_groups_state",
-                            values=[
-                                {
-                                    "state_group": state_group,
-                                    "room_id": room_id,
-                                    "type": key[0],
-                                    "state_key": key[1],
-                                    "event_id": state_id,
-                                }
-                                for key, state_id in iteritems(delta_state)
-                            ],
-                        )
-
-            progress = {
-                "last_state_group": state_group,
-                "rows_inserted": rows_inserted + batch_size,
-                "max_group": max_group,
-            }
-
-            self._background_update_progress_txn(
-                txn, self.STATE_GROUP_DEDUPLICATION_UPDATE_NAME, progress
-            )
-
-            return False, batch_size
-
-        finished, result = yield self.runInteraction(
-            self.STATE_GROUP_DEDUPLICATION_UPDATE_NAME, reindex_txn
+        return self.stores.state.store_state_group(
+            event_id, room_id, prev_group, delta_ids, current_state_ids
         )
-
-        if finished:
-            yield self._end_background_update(
-                self.STATE_GROUP_DEDUPLICATION_UPDATE_NAME
-            )
-
-        return result * BATCH_SIZE_SCALE_FACTOR
-
-    @defer.inlineCallbacks
-    def _background_index_state(self, progress, batch_size):
-        def reindex_txn(conn):
-            conn.rollback()
-            if isinstance(self.database_engine, PostgresEngine):
-                # postgres insists on autocommit for the index
-                conn.set_session(autocommit=True)
-                try:
-                    txn = conn.cursor()
-                    txn.execute(
-                        "CREATE INDEX CONCURRENTLY state_groups_state_type_idx"
-                        " ON state_groups_state(state_group, type, state_key)"
-                    )
-                    txn.execute("DROP INDEX IF EXISTS state_groups_state_id")
-                finally:
-                    conn.set_session(autocommit=False)
-            else:
-                txn = conn.cursor()
-                txn.execute(
-                    "CREATE INDEX state_groups_state_type_idx"
-                    " ON state_groups_state(state_group, type, state_key)"
-                )
-                txn.execute("DROP INDEX IF EXISTS state_groups_state_id")
-
-        yield self.runWithConnection(reindex_txn)
-
-        yield self._end_background_update(self.STATE_GROUP_INDEX_UPDATE_NAME)
-
-        return 1
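
The StateStore docstring removed above explains the state-group scheme only in prose. As a minimal, purely illustrative sketch (not part of the patch), resolving an event's state from those three tables could look like the following, given a plain DB-API2 cursor; the real lookup lives in _get_state_groups_from_groups_txn and additionally handles state filtering and caching.

def resolve_state_for_event(txn, event_id):
    """Illustrative only: find the event's state group, then walk
    state_group_edges backwards, letting newer delta entries win."""
    txn.execute(
        "SELECT state_group FROM event_to_state_groups WHERE event_id = ?",
        (event_id,),
    )
    row = txn.fetchone()
    if not row:
        return {}
    (next_group,) = row

    state = {}
    while next_group is not None:
        txn.execute(
            "SELECT type, state_key, event_id FROM state_groups_state"
            " WHERE state_group = ?",
            (next_group,),
        )
        for typ, state_key, ev_id in txn.fetchall():
            # The newest group in the chain takes precedence.
            state.setdefault((typ, state_key), ev_id)

        txn.execute(
            "SELECT prev_state_group FROM state_group_edges"
            " WHERE state_group = ?",
            (next_group,),
        )
        row = txn.fetchone()
        next_group = row[0] if row else None

    return state
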
diff --git a/synapse/storage/types.py b/synapse/storage/types.py
new file mode 100644
index 0000000000..daff81c5ee
--- /dev/null
+++ b/synapse/storage/types.py
@@ -0,0 +1,65 @@
+# -*- coding: utf-8 -*-
+# Copyright 2020 The Matrix.org Foundation C.I.C.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from typing import Any, Iterable, Iterator, List, Tuple
+
+from typing_extensions import Protocol
+
+
+"""
+Some very basic protocol definitions for the DB-API2 classes specified in PEP-249
+"""
+
+
+class Cursor(Protocol):
+    def execute(self, sql: str, parameters: Iterable[Any] = ...) -> Any:
+        ...
+
+    def executemany(self, sql: str, parameters: Iterable[Iterable[Any]]) -> Any:
+        ...
+
+    def fetchall(self) -> List[Tuple]:
+        ...
+
+    def fetchone(self) -> Tuple:
+        ...
+
+    @property
+    def description(self) -> Any:
+        return None
+
+    @property
+    def rowcount(self) -> int:
+        return 0
+
+    def __iter__(self) -> Iterator[Tuple]:
+        ...
+
+    def close(self) -> None:
+        ...
+
+
+class Connection(Protocol):
+    def cursor(self) -> Cursor:
+        ...
+
+    def close(self) -> None:
+        ...
+
+    def commit(self) -> None:
+        ...
+
+    def rollback(self, *args, **kwargs) -> None:
+        ...
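
These protocols pin down only the DB-API2 surface that the storage layer relies on, so helpers can be typed without importing a specific driver. A rough illustration (the helper below is hypothetical and not part of the patch):

import sqlite3
from typing import List, Tuple

from synapse.storage.types import Connection


def fetch_all(conn: Connection, sql: str) -> List[Tuple]:
    # Uses only methods declared on the Cursor/Connection protocols, so any
    # DB-API2 driver (sqlite3, psycopg2, ...) works at runtime.
    cur = conn.cursor()
    try:
        cur.execute(sql)
        return cur.fetchall()
    finally:
        cur.close()


# sqlite3 connections expose the same methods, so this runs as-is.
conn = sqlite3.connect(":memory:")
conn.cursor().execute("CREATE TABLE example (x INTEGER)")
print(fetch_all(conn, "SELECT x FROM example"))
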
diff --git a/synapse/storage/util/id_generators.py b/synapse/storage/util/id_generators.py
index cbb0a4810a..f89ce0bed2 100644
--- a/synapse/storage/util/id_generators.py
+++ b/synapse/storage/util/id_generators.py
@@ -16,6 +16,11 @@
 import contextlib
 import threading
 from collections import deque
+from typing import Dict, Set, Tuple
+
+from typing_extensions import Deque
+
+from synapse.storage.database import Database, LoggingTransaction
 
 
 class IdGenerator(object):
@@ -46,7 +51,7 @@ def _load_current_id(db_conn, table, column, step=1):
         cur.execute("SELECT MAX(%s) FROM %s" % (column, table))
     else:
         cur.execute("SELECT MIN(%s) FROM %s" % (column, table))
-    val, = cur.fetchone()
+    (val,) = cur.fetchone()
     cur.close()
     current_id = int(val) if val else step
     return (max if step > 0 else min)(current_id, step)
@@ -87,7 +92,7 @@ class StreamIdGenerator(object):
             self._current = (max if step > 0 else min)(
                 self._current, _load_current_id(db_conn, table, column, step)
             )
-        self._unfinished_ids = deque()
+        self._unfinished_ids = deque()  # type: Deque[int]
 
     def get_next(self):
         """
@@ -161,9 +166,10 @@ class ChainedIdGenerator(object):
 
     def __init__(self, chained_generator, db_conn, table, column):
         self.chained_generator = chained_generator
+        self._table = table
         self._lock = threading.Lock()
         self._current_max = _load_current_id(db_conn, table, column)
-        self._unfinished_ids = deque()
+        self._unfinished_ids = deque()  # type: Deque[Tuple[int, int]]
 
     def get_next(self):
         """
@@ -198,3 +204,173 @@ class ChainedIdGenerator(object):
                 return stream_id - 1, chained_id
 
             return self._current_max, self.chained_generator.get_current_token()
+
+    def advance(self, token: int):
+        """Stub implementation for advancing the token when receiving updates
+        over replication; raises an exception as this instance should be the
+        only source of updates.
+        """
+
+        raise Exception(
+            "Attempted to advance token on source for table %r" % (self._table,)
+        )
+
+
+class MultiWriterIdGenerator:
+    """An ID generator that tracks a stream that can have multiple writers.
+
+    Uses a Postgres sequence to coordinate ID assignment, but positions of other
+    writers will only get updated when `advance` is called (by replication).
+
+    Note: Only works with Postgres.
+
+    Args:
+        db_conn
+        db
+        instance_name: The name of this instance.
+        table: Database table associated with stream.
+        instance_column: Column that stores the name of the instance that wrote the row.
+        id_column: Column that stores the stream ID.
+        sequence_name: The name of the postgres sequence used to generate new
+            IDs.
+    """
+
+    def __init__(
+        self,
+        db_conn,
+        db: Database,
+        instance_name: str,
+        table: str,
+        instance_column: str,
+        id_column: str,
+        sequence_name: str,
+    ):
+        self._db = db
+        self._instance_name = instance_name
+        self._sequence_name = sequence_name
+
+        # We lock as some functions may be called from DB threads.
+        self._lock = threading.Lock()
+
+        self._current_positions = self._load_current_ids(
+            db_conn, table, instance_column, id_column
+        )
+
+        # Set of local IDs that we're still processing. The current position
+        # should be less than the minimum of this set (if not empty).
+        self._unfinished_ids = set()  # type: Set[int]
+
+    def _load_current_ids(
+        self, db_conn, table: str, instance_column: str, id_column: str
+    ) -> Dict[str, int]:
+        sql = """
+            SELECT %(instance)s, MAX(%(id)s) FROM %(table)s
+            GROUP BY %(instance)s
+        """ % {
+            "instance": instance_column,
+            "id": id_column,
+            "table": table,
+        }
+
+        cur = db_conn.cursor()
+        cur.execute(sql)
+
+        # `cur` is an iterable over returned rows, which are 2-tuples.
+        current_positions = dict(cur)
+
+        cur.close()
+
+        return current_positions
+
+    def _load_next_id_txn(self, txn):
+        txn.execute("SELECT nextval(?)", (self._sequence_name,))
+        (next_id,) = txn.fetchone()
+        return next_id
+
+    async def get_next(self):
+        """
+        Usage:
+            with await stream_id_gen.get_next() as stream_id:
+                # ... persist event ...
+        """
+        next_id = await self._db.runInteraction("_load_next_id", self._load_next_id_txn)
+
+        # Assert the fetched ID is actually greater than what we currently
+        # believe the ID to be. If not, then the sequence and table have got
+        # out of sync somehow.
+        assert self.get_current_token() < next_id
+
+        with self._lock:
+            self._unfinished_ids.add(next_id)
+
+        @contextlib.contextmanager
+        def manager():
+            try:
+                yield next_id
+            finally:
+                self._mark_id_as_finished(next_id)
+
+        return manager()
+
+    def get_next_txn(self, txn: LoggingTransaction):
+        """
+        Usage:
+
+            stream_id = stream_id_gen.get_next_txn(txn)
+            # ... persist event ...
+        """
+
+        next_id = self._load_next_id_txn(txn)
+
+        with self._lock:
+            self._unfinished_ids.add(next_id)
+
+        txn.call_after(self._mark_id_as_finished, next_id)
+        txn.call_on_exception(self._mark_id_as_finished, next_id)
+
+        return next_id
+
+    def _mark_id_as_finished(self, next_id: int):
+        """The ID has finished being processed so we should advance the
+        current poistion if possible.
+        """
+
+        with self._lock:
+            self._unfinished_ids.discard(next_id)
+
+            # Figure out if it's safe to advance the position by checking there
+            # aren't any lower allocated IDs that are yet to finish.
+            if all(c > next_id for c in self._unfinished_ids):
+                curr = self._current_positions.get(self._instance_name, 0)
+                self._current_positions[self._instance_name] = max(curr, next_id)
+
+    def get_current_token(self, instance_name: str = None) -> int:
+        """Gets the current position of a named writer (defaults to current
+        instance).
+
+        Returns 0 if we don't have a position for the named writer (likely due
+        to it being a new writer).
+        """
+
+        if instance_name is None:
+            instance_name = self._instance_name
+
+        with self._lock:
+            return self._current_positions.get(instance_name, 0)
+
+    def get_positions(self) -> Dict[str, int]:
+        """Get a copy of the current positon map.
+        """
+
+        with self._lock:
+            return dict(self._current_positions)
+
+    def advance(self, instance_name: str, new_id: int):
+        """Advance the postion of the named writer to the given ID, if greater
+        than existing entry.
+        """
+
+        with self._lock:
+            self._current_positions[instance_name] = max(
+                new_id, self._current_positions.get(instance_name, 0)
+            )
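
Pieced together from the docstrings above, the intended usage of the new generator looks roughly like the sketch below; the table, column and sequence names are invented for illustration, and db/db_conn stand for the Database wrapper and raw connection the constructor expects.

from synapse.storage.util.id_generators import MultiWriterIdGenerator


async def example(db_conn, db, store_row):
    id_gen = MultiWriterIdGenerator(
        db_conn,
        db,
        instance_name="worker1",             # this writer's name
        table="example_stream",              # hypothetical table
        instance_column="instance_name",     # hypothetical column names
        id_column="stream_id",
        sequence_name="example_stream_seq",  # hypothetical Postgres sequence
    )

    # Local write: allocate an ID from the sequence; the local position only
    # advances once the context manager exits.
    with await id_gen.get_next() as stream_id:
        await store_row(stream_id)

    # Remote write: replication tells us how far another writer has got.
    id_gen.advance("worker2", 42)

    # Positions can then be read per writer or as a full map.
    return (
        id_gen.get_current_token(),          # this instance
        id_gen.get_current_token("worker2"),
        id_gen.get_positions(),
    )

Because only the local instance draws IDs from the Postgres sequence, the recorded positions of other writers stay put until advance is called for them from the replication stream.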