summary refs log tree commit diff
path: root/synapse
diff options
context:
space:
mode:
Diffstat (limited to 'synapse')
-rw-r--r--synapse/api/auth.py2
-rw-r--r--synapse/app/_base.py10
-rw-r--r--synapse/app/admin_cmd.py11
-rw-r--r--synapse/app/appservice.py5
-rw-r--r--synapse/app/client_reader.py5
-rw-r--r--synapse/app/event_creator.py5
-rw-r--r--synapse/app/federation_reader.py5
-rw-r--r--synapse/app/federation_sender.py5
-rw-r--r--synapse/app/frontend_proxy.py5
-rw-r--r--synapse/app/homeserver.py22
-rw-r--r--synapse/app/media_repository.py5
-rw-r--r--synapse/app/pusher.py5
-rw-r--r--synapse/app/synchrotron.py5
-rw-r--r--synapse/app/user_dir.py5
-rw-r--r--synapse/config/database.py78
-rw-r--r--synapse/config/emailconfig.py3
-rw-r--r--synapse/config/key.py23
-rw-r--r--synapse/config/ratelimiting.py3
-rw-r--r--synapse/config/saml2_config.py186
-rw-r--r--synapse/config/server.py15
-rw-r--r--synapse/event_auth.py8
-rw-r--r--synapse/events/snapshot.py36
-rw-r--r--synapse/events/third_party_rules.py2
-rw-r--r--synapse/federation/federation_client.py88
-rw-r--r--synapse/federation/federation_server.py15
-rw-r--r--synapse/federation/transport/client.py33
-rw-r--r--synapse/federation/transport/server.py32
-rw-r--r--synapse/groups/groups_server.py5
-rw-r--r--synapse/handlers/_base.py2
-rw-r--r--synapse/handlers/account_data.py15
-rw-r--r--synapse/handlers/account_validity.py86
-rw-r--r--synapse/handlers/admin.py41
-rw-r--r--synapse/handlers/deactivate_account.py54
-rw-r--r--synapse/handlers/e2e_keys.py35
-rw-r--r--synapse/handlers/federation.py221
-rw-r--r--synapse/handlers/initial_sync.py115
-rw-r--r--synapse/handlers/message.py15
-rw-r--r--synapse/handlers/pagination.py27
-rw-r--r--synapse/handlers/presence.py2
-rw-r--r--synapse/handlers/profile.py6
-rw-r--r--synapse/handlers/room.py14
-rw-r--r--synapse/handlers/room_member.py4
-rw-r--r--synapse/handlers/saml_handler.py198
-rw-r--r--synapse/handlers/search.py34
-rw-r--r--synapse/handlers/typing.py3
-rw-r--r--synapse/logging/_terse_json.py2
-rw-r--r--synapse/logging/context.py14
-rw-r--r--synapse/push/bulk_push_rule_evaluator.py4
-rw-r--r--synapse/push/pusher.py12
-rw-r--r--synapse/push/pusherpool.py10
-rw-r--r--synapse/replication/http/federation.py5
-rw-r--r--synapse/replication/http/send_event.py3
-rw-r--r--synapse/rest/client/v1/pusher.py33
-rw-r--r--synapse/server.py39
-rw-r--r--synapse/state/__init__.py3
-rw-r--r--synapse/storage/_base.py2
-rw-r--r--synapse/storage/data_stores/__init__.py45
-rw-r--r--synapse/storage/data_stores/main/client_ips.py10
-rw-r--r--synapse/storage/data_stores/main/deviceinbox.py17
-rw-r--r--synapse/storage/data_stores/main/end_to_end_keys.py217
-rw-r--r--synapse/storage/data_stores/main/events.py157
-rw-r--r--synapse/storage/data_stores/main/events_worker.py97
-rw-r--r--synapse/storage/data_stores/main/push_rule.py2
-rw-r--r--synapse/storage/data_stores/main/pusher.py25
-rw-r--r--synapse/storage/data_stores/main/roommember.py2
-rw-r--r--synapse/storage/data_stores/main/schema/delta/32/remove_indices.sql1
-rw-r--r--synapse/storage/data_stores/main/schema/delta/56/device_stream_id_insert.sql20
-rw-r--r--synapse/storage/data_stores/main/schema/delta/56/nuke_empty_communities_from_db.sql29
-rw-r--r--synapse/storage/data_stores/main/schema/full_schemas/54/full.sql.postgres52
-rw-r--r--synapse/storage/data_stores/main/schema/full_schemas/54/full.sql.sqlite6
-rw-r--r--synapse/storage/data_stores/main/schema/full_schemas/54/stream_positions.sql1
-rw-r--r--synapse/storage/data_stores/main/schema/full_schemas/README.md13
-rw-r--r--synapse/storage/data_stores/main/schema/full_schemas/README.txt19
-rw-r--r--synapse/storage/data_stores/main/search.py19
-rw-r--r--synapse/storage/data_stores/main/state.py955
-rw-r--r--synapse/storage/data_stores/state/__init__.py16
-rw-r--r--synapse/storage/data_stores/state/bg_updates.py374
-rw-r--r--synapse/storage/data_stores/state/schema/delta/23/drop_state_index.sql (renamed from synapse/storage/data_stores/main/schema/delta/23/drop_state_index.sql)0
-rw-r--r--synapse/storage/data_stores/state/schema/delta/30/state_stream.sql (renamed from synapse/storage/data_stores/main/schema/delta/30/state_stream.sql)0
-rw-r--r--synapse/storage/data_stores/state/schema/delta/32/remove_state_indices.sql19
-rw-r--r--synapse/storage/data_stores/state/schema/delta/35/add_state_index.sql (renamed from synapse/storage/data_stores/main/schema/delta/35/add_state_index.sql)0
-rw-r--r--synapse/storage/data_stores/state/schema/delta/35/state.sql (renamed from synapse/storage/data_stores/main/schema/delta/35/state.sql)0
-rw-r--r--synapse/storage/data_stores/state/schema/delta/35/state_dedupe.sql (renamed from synapse/storage/data_stores/main/schema/delta/35/state_dedupe.sql)0
-rw-r--r--synapse/storage/data_stores/state/schema/delta/47/state_group_seq.py (renamed from synapse/storage/data_stores/main/schema/delta/47/state_group_seq.py)0
-rw-r--r--synapse/storage/data_stores/state/schema/delta/56/state_group_room_idx.sql17
-rw-r--r--synapse/storage/data_stores/state/schema/full_schemas/54/full.sql37
-rw-r--r--synapse/storage/data_stores/state/schema/full_schemas/54/sequence.sql.postgres21
-rw-r--r--synapse/storage/data_stores/state/store.py640
-rw-r--r--synapse/storage/database.py45
-rw-r--r--synapse/storage/engines/sqlite.py6
-rw-r--r--synapse/storage/persist_events.py2
-rw-r--r--synapse/storage/prepare_database.py30
-rw-r--r--synapse/storage/purge_events.py4
-rw-r--r--synapse/storage/state.py14
-rw-r--r--synapse/streams/events.py4
-rw-r--r--synapse/util/caches/descriptors.py2
-rw-r--r--synapse/util/caches/snapshot_cache.py94
97 files changed, 2603 insertions, 2030 deletions
diff --git a/synapse/api/auth.py b/synapse/api/auth.py
index 9fd52a8c77..abbc7079a3 100644
--- a/synapse/api/auth.py
+++ b/synapse/api/auth.py
@@ -79,7 +79,7 @@ class Auth(object):
 
     @defer.inlineCallbacks
     def check_from_context(self, room_version, event, context, do_sig_check=True):
-        prev_state_ids = yield context.get_prev_state_ids(self.store)
+        prev_state_ids = yield context.get_prev_state_ids()
         auth_events_ids = yield self.compute_auth_events(
             event, prev_state_ids, for_verification=True
         )
diff --git a/synapse/app/_base.py b/synapse/app/_base.py
index 9c96816096..0e8b467a3e 100644
--- a/synapse/app/_base.py
+++ b/synapse/app/_base.py
@@ -237,6 +237,12 @@ def start(hs, listeners=None):
     """
     Start a Synapse server or worker.
 
+    Should be called once the reactor is running and (if we're using ACME) the
+    TLS certificates are in place.
+
+    Will start the main HTTP listeners and do some other startup tasks, and then
+    notify systemd.
+
     Args:
         hs (synapse.server.HomeServer)
         listeners (list[dict]): Listener configuration ('listeners' in homeserver.yaml)
@@ -311,9 +317,7 @@ def setup_sdnotify(hs):
 
     # Tell systemd our state, if we're using it. This will silently fail if
     # we're not using systemd.
-    hs.get_reactor().addSystemEventTrigger(
-        "after", "startup", sdnotify, b"READY=1\nMAINPID=%i" % (os.getpid(),)
-    )
+    sdnotify(b"READY=1\nMAINPID=%i" % (os.getpid(),))
 
     hs.get_reactor().addSystemEventTrigger(
         "before", "shutdown", sdnotify, b"STOPPING=1"
diff --git a/synapse/app/admin_cmd.py b/synapse/app/admin_cmd.py
index 04751a6a5e..8e36bc57d3 100644
--- a/synapse/app/admin_cmd.py
+++ b/synapse/app/admin_cmd.py
@@ -45,7 +45,6 @@ from synapse.replication.slave.storage.registration import SlavedRegistrationSto
 from synapse.replication.slave.storage.room import RoomStore
 from synapse.replication.tcp.client import ReplicationClientHandler
 from synapse.server import HomeServer
-from synapse.storage.engines import create_engine
 from synapse.util.logcontext import LoggingContext
 from synapse.util.versionstring import get_version_string
 
@@ -105,8 +104,10 @@ def export_data_command(hs, args):
     user_id = args.user_id
     directory = args.output_directory
 
-    res = yield hs.get_handlers().admin_handler.export_user_data(
-        user_id, FileExfiltrationWriter(user_id, directory=directory)
+    res = yield defer.ensureDeferred(
+        hs.get_handlers().admin_handler.export_user_data(
+            user_id, FileExfiltrationWriter(user_id, directory=directory)
+        )
     )
     print(res)
 
@@ -229,14 +230,10 @@ def start(config_options):
 
     synapse.events.USE_FROZEN_DICTS = config.use_frozen_dicts
 
-    database_engine = create_engine(config.database_config)
-
     ss = AdminCmdServer(
         config.server_name,
-        db_config=config.database_config,
         config=config,
         version_string="Synapse/" + get_version_string(synapse),
-        database_engine=database_engine,
     )
 
     setup_logging(ss, config, use_worker_options=True)
diff --git a/synapse/app/appservice.py b/synapse/app/appservice.py
index 02b900f382..e82e0f11e3 100644
--- a/synapse/app/appservice.py
+++ b/synapse/app/appservice.py
@@ -34,7 +34,6 @@ from synapse.replication.slave.storage.events import SlavedEventStore
 from synapse.replication.slave.storage.registration import SlavedRegistrationStore
 from synapse.replication.tcp.client import ReplicationClientHandler
 from synapse.server import HomeServer
-from synapse.storage.engines import create_engine
 from synapse.util.httpresourcetree import create_resource_tree
 from synapse.util.manhole import manhole
 from synapse.util.versionstring import get_version_string
@@ -143,8 +142,6 @@ def start(config_options):
 
     events.USE_FROZEN_DICTS = config.use_frozen_dicts
 
-    database_engine = create_engine(config.database_config)
-
     if config.notify_appservices:
         sys.stderr.write(
             "\nThe appservices must be disabled in the main synapse process"
@@ -159,10 +156,8 @@ def start(config_options):
 
     ps = AppserviceServer(
         config.server_name,
-        db_config=config.database_config,
         config=config,
         version_string="Synapse/" + get_version_string(synapse),
-        database_engine=database_engine,
     )
 
     setup_logging(ps, config, use_worker_options=True)
diff --git a/synapse/app/client_reader.py b/synapse/app/client_reader.py
index dadb487d5f..3edfe19567 100644
--- a/synapse/app/client_reader.py
+++ b/synapse/app/client_reader.py
@@ -62,7 +62,6 @@ from synapse.rest.client.v2_alpha.keys import KeyChangesServlet, KeyQueryServlet
 from synapse.rest.client.v2_alpha.register import RegisterRestServlet
 from synapse.rest.client.versions import VersionsRestServlet
 from synapse.server import HomeServer
-from synapse.storage.engines import create_engine
 from synapse.util.httpresourcetree import create_resource_tree
 from synapse.util.manhole import manhole
 from synapse.util.versionstring import get_version_string
@@ -181,14 +180,10 @@ def start(config_options):
 
     events.USE_FROZEN_DICTS = config.use_frozen_dicts
 
-    database_engine = create_engine(config.database_config)
-
     ss = ClientReaderServer(
         config.server_name,
-        db_config=config.database_config,
         config=config,
         version_string="Synapse/" + get_version_string(synapse),
-        database_engine=database_engine,
     )
 
     setup_logging(ss, config, use_worker_options=True)
diff --git a/synapse/app/event_creator.py b/synapse/app/event_creator.py
index d110599a35..d0ddbe38fc 100644
--- a/synapse/app/event_creator.py
+++ b/synapse/app/event_creator.py
@@ -57,7 +57,6 @@ from synapse.rest.client.v1.room import (
 )
 from synapse.server import HomeServer
 from synapse.storage.data_stores.main.user_directory import UserDirectoryStore
-from synapse.storage.engines import create_engine
 from synapse.util.httpresourcetree import create_resource_tree
 from synapse.util.manhole import manhole
 from synapse.util.versionstring import get_version_string
@@ -180,14 +179,10 @@ def start(config_options):
 
     events.USE_FROZEN_DICTS = config.use_frozen_dicts
 
-    database_engine = create_engine(config.database_config)
-
     ss = EventCreatorServer(
         config.server_name,
-        db_config=config.database_config,
         config=config,
         version_string="Synapse/" + get_version_string(synapse),
-        database_engine=database_engine,
     )
 
     setup_logging(ss, config, use_worker_options=True)
diff --git a/synapse/app/federation_reader.py b/synapse/app/federation_reader.py
index 418c086254..311523e0ed 100644
--- a/synapse/app/federation_reader.py
+++ b/synapse/app/federation_reader.py
@@ -46,7 +46,6 @@ from synapse.replication.slave.storage.transactions import SlavedTransactionStor
 from synapse.replication.tcp.client import ReplicationClientHandler
 from synapse.rest.key.v2 import KeyApiV2Resource
 from synapse.server import HomeServer
-from synapse.storage.engines import create_engine
 from synapse.util.httpresourcetree import create_resource_tree
 from synapse.util.manhole import manhole
 from synapse.util.versionstring import get_version_string
@@ -162,14 +161,10 @@ def start(config_options):
 
     events.USE_FROZEN_DICTS = config.use_frozen_dicts
 
-    database_engine = create_engine(config.database_config)
-
     ss = FederationReaderServer(
         config.server_name,
-        db_config=config.database_config,
         config=config,
         version_string="Synapse/" + get_version_string(synapse),
-        database_engine=database_engine,
     )
 
     setup_logging(ss, config, use_worker_options=True)
diff --git a/synapse/app/federation_sender.py b/synapse/app/federation_sender.py
index f24920a7d6..83c436229c 100644
--- a/synapse/app/federation_sender.py
+++ b/synapse/app/federation_sender.py
@@ -41,7 +41,6 @@ from synapse.replication.tcp.client import ReplicationClientHandler
 from synapse.replication.tcp.streams._base import ReceiptsStream
 from synapse.server import HomeServer
 from synapse.storage.database import Database
-from synapse.storage.engines import create_engine
 from synapse.types import ReadReceipt
 from synapse.util.async_helpers import Linearizer
 from synapse.util.httpresourcetree import create_resource_tree
@@ -174,8 +173,6 @@ def start(config_options):
 
     events.USE_FROZEN_DICTS = config.use_frozen_dicts
 
-    database_engine = create_engine(config.database_config)
-
     if config.send_federation:
         sys.stderr.write(
             "\nThe send_federation must be disabled in the main synapse process"
@@ -190,10 +187,8 @@ def start(config_options):
 
     ss = FederationSenderServer(
         config.server_name,
-        db_config=config.database_config,
         config=config,
         version_string="Synapse/" + get_version_string(synapse),
-        database_engine=database_engine,
     )
 
     setup_logging(ss, config, use_worker_options=True)
diff --git a/synapse/app/frontend_proxy.py b/synapse/app/frontend_proxy.py
index e647459d0e..30e435eead 100644
--- a/synapse/app/frontend_proxy.py
+++ b/synapse/app/frontend_proxy.py
@@ -39,7 +39,6 @@ from synapse.replication.slave.storage.registration import SlavedRegistrationSto
 from synapse.replication.tcp.client import ReplicationClientHandler
 from synapse.rest.client.v2_alpha._base import client_patterns
 from synapse.server import HomeServer
-from synapse.storage.engines import create_engine
 from synapse.util.httpresourcetree import create_resource_tree
 from synapse.util.manhole import manhole
 from synapse.util.versionstring import get_version_string
@@ -234,14 +233,10 @@ def start(config_options):
 
     events.USE_FROZEN_DICTS = config.use_frozen_dicts
 
-    database_engine = create_engine(config.database_config)
-
     ss = FrontendProxyServer(
         config.server_name,
-        db_config=config.database_config,
         config=config,
         version_string="Synapse/" + get_version_string(synapse),
-        database_engine=database_engine,
     )
 
     setup_logging(ss, config, use_worker_options=True)
diff --git a/synapse/app/homeserver.py b/synapse/app/homeserver.py
index df65d0a989..0e9bf7f53a 100644
--- a/synapse/app/homeserver.py
+++ b/synapse/app/homeserver.py
@@ -69,7 +69,7 @@ from synapse.rest.media.v0.content_repository import ContentRepoResource
 from synapse.rest.well_known import WellKnownResource
 from synapse.server import HomeServer
 from synapse.storage import DataStore
-from synapse.storage.engines import IncorrectDatabaseSetup, create_engine
+from synapse.storage.engines import IncorrectDatabaseSetup
 from synapse.storage.prepare_database import UpgradeDatabaseException
 from synapse.util.caches import CACHE_SIZE_FACTOR
 from synapse.util.httpresourcetree import create_resource_tree
@@ -328,15 +328,10 @@ def setup(config_options):
 
     events.USE_FROZEN_DICTS = config.use_frozen_dicts
 
-    database_engine = create_engine(config.database_config)
-    config.database_config["args"]["cp_openfun"] = database_engine.on_new_connection
-
     hs = SynapseHomeServer(
         config.server_name,
-        db_config=config.database_config,
         config=config,
         version_string="Synapse/" + get_version_string(synapse),
-        database_engine=database_engine,
     )
 
     synapse.config.logger.setup_logging(hs, config, use_worker_options=False)
@@ -347,13 +342,8 @@ def setup(config_options):
         hs.setup()
     except IncorrectDatabaseSetup as e:
         quit_with_error(str(e))
-    except UpgradeDatabaseException:
-        sys.stderr.write(
-            "\nFailed to upgrade database.\n"
-            "Have you checked for version specific instructions in"
-            " UPGRADES.rst?\n"
-        )
-        sys.exit(1)
+    except UpgradeDatabaseException as e:
+        quit_with_error("Failed to upgrade database: %s" % (e,))
 
     hs.setup_master()
 
@@ -519,8 +509,10 @@ def phone_stats_home(hs, stats, stats_process=_stats_process):
     # Database version
     #
 
-    stats["database_engine"] = hs.database_engine.module.__name__
-    stats["database_server_version"] = hs.database_engine.server_version
+    # This only reports info about the *main* database.
+    stats["database_engine"] = hs.get_datastore().db.engine.module.__name__
+    stats["database_server_version"] = hs.get_datastore().db.engine.server_version
+
     logger.info("Reporting stats to %s: %s" % (hs.config.report_stats_endpoint, stats))
     try:
         yield hs.get_proxied_http_client().put_json(
diff --git a/synapse/app/media_repository.py b/synapse/app/media_repository.py
index 2c6dd3ef02..4c80f257e2 100644
--- a/synapse/app/media_repository.py
+++ b/synapse/app/media_repository.py
@@ -40,7 +40,6 @@ from synapse.rest.admin import register_servlets_for_media_repo
 from synapse.rest.media.v0.content_repository import ContentRepoResource
 from synapse.server import HomeServer
 from synapse.storage.data_stores.main.media_repository import MediaRepositoryStore
-from synapse.storage.engines import create_engine
 from synapse.util.httpresourcetree import create_resource_tree
 from synapse.util.manhole import manhole
 from synapse.util.versionstring import get_version_string
@@ -157,14 +156,10 @@ def start(config_options):
 
     events.USE_FROZEN_DICTS = config.use_frozen_dicts
 
-    database_engine = create_engine(config.database_config)
-
     ss = MediaRepositoryServer(
         config.server_name,
-        db_config=config.database_config,
         config=config,
         version_string="Synapse/" + get_version_string(synapse),
-        database_engine=database_engine,
     )
 
     setup_logging(ss, config, use_worker_options=True)
diff --git a/synapse/app/pusher.py b/synapse/app/pusher.py
index dd52a9fc2d..09e639040a 100644
--- a/synapse/app/pusher.py
+++ b/synapse/app/pusher.py
@@ -37,7 +37,6 @@ from synapse.replication.slave.storage.room import RoomStore
 from synapse.replication.tcp.client import ReplicationClientHandler
 from synapse.server import HomeServer
 from synapse.storage import DataStore
-from synapse.storage.engines import create_engine
 from synapse.util.httpresourcetree import create_resource_tree
 from synapse.util.manhole import manhole
 from synapse.util.versionstring import get_version_string
@@ -203,14 +202,10 @@ def start(config_options):
     # Force the pushers to start since they will be disabled in the main config
     config.start_pushers = True
 
-    database_engine = create_engine(config.database_config)
-
     ps = PusherServer(
         config.server_name,
-        db_config=config.database_config,
         config=config,
         version_string="Synapse/" + get_version_string(synapse),
-        database_engine=database_engine,
     )
 
     setup_logging(ps, config, use_worker_options=True)
diff --git a/synapse/app/synchrotron.py b/synapse/app/synchrotron.py
index 288ee64b42..dd2132e608 100644
--- a/synapse/app/synchrotron.py
+++ b/synapse/app/synchrotron.py
@@ -55,7 +55,6 @@ from synapse.rest.client.v1.room import RoomInitialSyncRestServlet
 from synapse.rest.client.v2_alpha import sync
 from synapse.server import HomeServer
 from synapse.storage.data_stores.main.presence import UserPresenceState
-from synapse.storage.engines import create_engine
 from synapse.util.httpresourcetree import create_resource_tree
 from synapse.util.manhole import manhole
 from synapse.util.stringutils import random_string
@@ -437,14 +436,10 @@ def start(config_options):
 
     synapse.events.USE_FROZEN_DICTS = config.use_frozen_dicts
 
-    database_engine = create_engine(config.database_config)
-
     ss = SynchrotronServer(
         config.server_name,
-        db_config=config.database_config,
         config=config,
         version_string="Synapse/" + get_version_string(synapse),
-        database_engine=database_engine,
         application_service_handler=SynchrotronApplicationService(),
     )
 
diff --git a/synapse/app/user_dir.py b/synapse/app/user_dir.py
index c01fb34a9b..1257098f92 100644
--- a/synapse/app/user_dir.py
+++ b/synapse/app/user_dir.py
@@ -44,7 +44,6 @@ from synapse.rest.client.v2_alpha import user_directory
 from synapse.server import HomeServer
 from synapse.storage.data_stores.main.user_directory import UserDirectoryStore
 from synapse.storage.database import Database
-from synapse.storage.engines import create_engine
 from synapse.util.caches.stream_change_cache import StreamChangeCache
 from synapse.util.httpresourcetree import create_resource_tree
 from synapse.util.manhole import manhole
@@ -200,8 +199,6 @@ def start(config_options):
 
     events.USE_FROZEN_DICTS = config.use_frozen_dicts
 
-    database_engine = create_engine(config.database_config)
-
     if config.update_user_directory:
         sys.stderr.write(
             "\nThe update_user_directory must be disabled in the main synapse process"
@@ -216,10 +213,8 @@ def start(config_options):
 
     ss = UserDirectoryServer(
         config.server_name,
-        db_config=config.database_config,
         config=config,
         version_string="Synapse/" + get_version_string(synapse),
-        database_engine=database_engine,
     )
 
     setup_logging(ss, config, use_worker_options=True)
diff --git a/synapse/config/database.py b/synapse/config/database.py
index 0e2509f0b1..134824789c 100644
--- a/synapse/config/database.py
+++ b/synapse/config/database.py
@@ -12,12 +12,45 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
+import logging
 import os
 from textwrap import indent
+from typing import List
 
 import yaml
 
-from ._base import Config
+from synapse.config._base import Config, ConfigError
+
+logger = logging.getLogger(__name__)
+
+
+class DatabaseConnectionConfig:
+    """Contains the connection config for a particular database.
+
+    Args:
+        name: A label for the database, used for logging.
+        db_config: The config for a particular database, as per `database`
+            section of main config. Has two fields: `name` for database
+            module name, and `args` for the args to give to the database
+            connector.
+        data_stores: The list of data stores that should be provisioned on the
+            database. Defaults to all data stores.
+    """
+
+    def __init__(
+        self, name: str, db_config: dict, data_stores: List[str] = ["main", "state"]
+    ):
+        if db_config["name"] not in ("sqlite3", "psycopg2"):
+            raise ConfigError("Unsupported database type %r" % (db_config["name"],))
+
+        if db_config["name"] == "sqlite3":
+            db_config.setdefault("args", {}).update(
+                {"cp_min": 1, "cp_max": 1, "check_same_thread": False}
+            )
+
+        self.name = name
+        self.config = db_config
+        self.data_stores = data_stores
 
 
 class DatabaseConfig(Config):
@@ -26,20 +59,12 @@ class DatabaseConfig(Config):
     def read_config(self, config, **kwargs):
         self.event_cache_size = self.parse_size(config.get("event_cache_size", "10K"))
 
-        self.database_config = config.get("database")
+        database_config = config.get("database")
 
-        if self.database_config is None:
-            self.database_config = {"name": "sqlite3", "args": {}}
+        if database_config is None:
+            database_config = {"name": "sqlite3", "args": {}}
 
-        name = self.database_config.get("name", None)
-        if name == "psycopg2":
-            pass
-        elif name == "sqlite3":
-            self.database_config.setdefault("args", {}).update(
-                {"cp_min": 1, "cp_max": 1, "check_same_thread": False}
-            )
-        else:
-            raise RuntimeError("Unsupported database type '%s'" % (name,))
+        self.databases = [DatabaseConnectionConfig("master", database_config)]
 
         self.set_databasepath(config.get("database_path"))
 
@@ -76,11 +101,24 @@ class DatabaseConfig(Config):
         self.set_databasepath(args.database_path)
 
     def set_databasepath(self, database_path):
+        if database_path is None:
+            return
+
         if database_path != ":memory:":
             database_path = self.abspath(database_path)
-        if self.database_config.get("name", None) == "sqlite3":
-            if database_path is not None:
-                self.database_config["args"]["database"] = database_path
+
+        # We only support setting a database path if we have a single sqlite3
+        # database.
+        if len(self.databases) != 1:
+            raise ConfigError("Cannot specify 'database_path' with multiple databases")
+
+        database = self.get_single_database()
+        if database.config["name"] != "sqlite3":
+            # We don't raise here as we haven't done so before for this case.
+            logger.warn("Ignoring 'database_path' for non-sqlite3 database")
+            return
+
+        database.config["args"]["database"] = database_path
 
     @staticmethod
     def add_arguments(parser):
@@ -91,3 +129,11 @@ class DatabaseConfig(Config):
             metavar="SQLITE_DATABASE_PATH",
             help="The path to a sqlite database to use.",
         )
+
+    def get_single_database(self) -> DatabaseConnectionConfig:
+        """Returns the database if there is only one, useful for e.g. tests
+        """
+        if len(self.databases) != 1:
+            raise Exception("More than one database exists")
+
+        return self.databases[0]
diff --git a/synapse/config/emailconfig.py b/synapse/config/emailconfig.py
index 18f42a87f9..35756bed87 100644
--- a/synapse/config/emailconfig.py
+++ b/synapse/config/emailconfig.py
@@ -21,6 +21,7 @@ from __future__ import print_function
 import email.utils
 import os
 from enum import Enum
+from typing import Optional
 
 import pkg_resources
 
@@ -101,7 +102,7 @@ class EmailConfig(Config):
                 # both in RegistrationConfig and here. We should factor this bit out
                 self.account_threepid_delegate_email = self.trusted_third_party_id_servers[
                     0
-                ]
+                ]  # type: Optional[str]
                 self.using_identity_server_from_trusted_list = True
             else:
                 raise ConfigError(
diff --git a/synapse/config/key.py b/synapse/config/key.py
index 52ff1b2621..066e7838c3 100644
--- a/synapse/config/key.py
+++ b/synapse/config/key.py
@@ -108,7 +108,7 @@ class KeyConfig(Config):
             self.signing_key = self.read_signing_keys(signing_key_path, "signing_key")
 
         self.old_signing_keys = self.read_old_signing_keys(
-            config.get("old_signing_keys", {})
+            config.get("old_signing_keys")
         )
         self.key_refresh_interval = self.parse_duration(
             config.get("key_refresh_interval", "1d")
@@ -199,14 +199,19 @@ class KeyConfig(Config):
         signing_key_path: "%(base_key_name)s.signing.key"
 
         # The keys that the server used to sign messages with but won't use
-        # to sign new messages. E.g. it has lost its private key
+        # to sign new messages.
         #
-        #old_signing_keys:
-        #  "ed25519:auto":
-        #    # Base64 encoded public key
-        #    key: "The public part of your old signing key."
-        #    # Millisecond POSIX timestamp when the key expired.
-        #    expired_ts: 123456789123
+        old_signing_keys:
+          # For each key, `key` should be the base64-encoded public key, and
+          # `expired_ts`should be the time (in milliseconds since the unix epoch) that
+          # it was last used.
+          #
+          # It is possible to build an entry from an old signing.key file using the
+          # `export_signing_key` script which is provided with synapse.
+          #
+          # For example:
+          #
+          #"ed25519:id": { key: "base64string", expired_ts: 123456789123 }
 
         # How long key response published by this server is valid for.
         # Used to set the valid_until_ts in /key/v2 APIs.
@@ -290,6 +295,8 @@ class KeyConfig(Config):
             raise ConfigError("Error reading %s: %s" % (name, str(e)))
 
     def read_old_signing_keys(self, old_signing_keys):
+        if old_signing_keys is None:
+            return {}
         keys = {}
         for key_id, key_data in old_signing_keys.items():
             if is_signing_algorithm_supported(key_id):
diff --git a/synapse/config/ratelimiting.py b/synapse/config/ratelimiting.py
index 947f653e03..4a3bfc4354 100644
--- a/synapse/config/ratelimiting.py
+++ b/synapse/config/ratelimiting.py
@@ -83,10 +83,9 @@ class RatelimitConfig(Config):
         )
 
         rc_admin_redaction = config.get("rc_admin_redaction")
+        self.rc_admin_redaction = None
         if rc_admin_redaction:
             self.rc_admin_redaction = RateLimitConfig(rc_admin_redaction)
-        else:
-            self.rc_admin_redaction = None
 
     def generate_config_section(self, **kwargs):
         return """\
diff --git a/synapse/config/saml2_config.py b/synapse/config/saml2_config.py
index c5ea2d43a1..b91414aa35 100644
--- a/synapse/config/saml2_config.py
+++ b/synapse/config/saml2_config.py
@@ -14,17 +14,19 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
-import re
+import logging
 
 from synapse.python_dependencies import DependencyException, check_requirements
-from synapse.types import (
-    map_username_to_mxid_localpart,
-    mxid_localpart_allowed_characters,
-)
-from synapse.util.module_loader import load_python_module
+from synapse.util.module_loader import load_module, load_python_module
 
 from ._base import Config, ConfigError
 
+logger = logging.getLogger(__name__)
+
+DEFAULT_USER_MAPPING_PROVIDER = (
+    "synapse.handlers.saml_handler.DefaultSamlMappingProvider"
+)
+
 
 def _dict_merge(merge_dict, into_dict):
     """Do a deep merge of two dicts
@@ -75,15 +77,69 @@ class SAML2Config(Config):
 
         self.saml2_enabled = True
 
-        self.saml2_mxid_source_attribute = saml2_config.get(
-            "mxid_source_attribute", "uid"
-        )
-
         self.saml2_grandfathered_mxid_source_attribute = saml2_config.get(
             "grandfathered_mxid_source_attribute", "uid"
         )
 
-        saml2_config_dict = self._default_saml_config_dict()
+        # user_mapping_provider may be None if the key is present but has no value
+        ump_dict = saml2_config.get("user_mapping_provider") or {}
+
+        # Use the default user mapping provider if not set
+        ump_dict.setdefault("module", DEFAULT_USER_MAPPING_PROVIDER)
+
+        # Ensure a config is present
+        ump_dict["config"] = ump_dict.get("config") or {}
+
+        if ump_dict["module"] == DEFAULT_USER_MAPPING_PROVIDER:
+            # Load deprecated options for use by the default module
+            old_mxid_source_attribute = saml2_config.get("mxid_source_attribute")
+            if old_mxid_source_attribute:
+                logger.warning(
+                    "The config option saml2_config.mxid_source_attribute is deprecated. "
+                    "Please use saml2_config.user_mapping_provider.config"
+                    ".mxid_source_attribute instead."
+                )
+                ump_dict["config"]["mxid_source_attribute"] = old_mxid_source_attribute
+
+            old_mxid_mapping = saml2_config.get("mxid_mapping")
+            if old_mxid_mapping:
+                logger.warning(
+                    "The config option saml2_config.mxid_mapping is deprecated. Please "
+                    "use saml2_config.user_mapping_provider.config.mxid_mapping instead."
+                )
+                ump_dict["config"]["mxid_mapping"] = old_mxid_mapping
+
+        # Retrieve an instance of the module's class
+        # Pass the config dictionary to the module for processing
+        (
+            self.saml2_user_mapping_provider_class,
+            self.saml2_user_mapping_provider_config,
+        ) = load_module(ump_dict)
+
+        # Ensure loaded user mapping module has defined all necessary methods
+        # Note parse_config() is already checked during the call to load_module
+        required_methods = [
+            "get_saml_attributes",
+            "saml_response_to_user_attributes",
+        ]
+        missing_methods = [
+            method
+            for method in required_methods
+            if not hasattr(self.saml2_user_mapping_provider_class, method)
+        ]
+        if missing_methods:
+            raise ConfigError(
+                "Class specified by saml2_config."
+                "user_mapping_provider.module is missing required "
+                "methods: %s" % (", ".join(missing_methods),)
+            )
+
+        # Get the desired saml auth response attributes from the module
+        saml2_config_dict = self._default_saml_config_dict(
+            *self.saml2_user_mapping_provider_class.get_saml_attributes(
+                self.saml2_user_mapping_provider_config
+            )
+        )
         _dict_merge(
             merge_dict=saml2_config.get("sp_config", {}), into_dict=saml2_config_dict
         )
@@ -103,22 +159,27 @@ class SAML2Config(Config):
             saml2_config.get("saml_session_lifetime", "5m")
         )
 
-        mapping = saml2_config.get("mxid_mapping", "hexencode")
-        try:
-            self.saml2_mxid_mapper = MXID_MAPPER_MAP[mapping]
-        except KeyError:
-            raise ConfigError("%s is not a known mxid_mapping" % (mapping,))
-
-    def _default_saml_config_dict(self):
+    def _default_saml_config_dict(
+        self, required_attributes: set, optional_attributes: set
+    ):
+        """Generate a configuration dictionary with required and optional attributes that
+        will be needed to process new user registration
+
+        Args:
+            required_attributes: SAML auth response attributes that are
+                necessary to function
+            optional_attributes: SAML auth response attributes that can be used to add
+                additional information to Synapse user accounts, but are not required
+
+        Returns:
+            dict: A SAML configuration dictionary
+        """
         import saml2
 
         public_baseurl = self.public_baseurl
         if public_baseurl is None:
             raise ConfigError("saml2_config requires a public_baseurl to be set")
 
-        required_attributes = {"uid", self.saml2_mxid_source_attribute}
-
-        optional_attributes = {"displayName"}
         if self.saml2_grandfathered_mxid_source_attribute:
             optional_attributes.add(self.saml2_grandfathered_mxid_source_attribute)
         optional_attributes -= required_attributes
@@ -207,33 +268,58 @@ class SAML2Config(Config):
           #
           #config_path: "%(config_dir_path)s/sp_conf.py"
 
-          # the lifetime of a SAML session. This defines how long a user has to
+          # The lifetime of a SAML session. This defines how long a user has to
           # complete the authentication process, if allow_unsolicited is unset.
           # The default is 5 minutes.
           #
           #saml_session_lifetime: 5m
 
-          # The SAML attribute (after mapping via the attribute maps) to use to derive
-          # the Matrix ID from. 'uid' by default.
+          # An external module can be provided here as a custom solution to
+          # mapping attributes returned from a saml provider onto a matrix user.
           #
-          #mxid_source_attribute: displayName
-
-          # The mapping system to use for mapping the saml attribute onto a matrix ID.
-          # Options include:
-          #  * 'hexencode' (which maps unpermitted characters to '=xx')
-          #  * 'dotreplace' (which replaces unpermitted characters with '.').
-          # The default is 'hexencode'.
-          #
-          #mxid_mapping: dotreplace
-
-          # In previous versions of synapse, the mapping from SAML attribute to MXID was
-          # always calculated dynamically rather than stored in a table. For backwards-
-          # compatibility, we will look for user_ids matching such a pattern before
-          # creating a new account.
+          user_mapping_provider:
+            # The custom module's class. Uncomment to use a custom module.
+            #
+            #module: mapping_provider.SamlMappingProvider
+
+            # Custom configuration values for the module. Below options are
+            # intended for the built-in provider, they should be changed if
+            # using a custom module. This section will be passed as a Python
+            # dictionary to the module's `parse_config` method.
+            #
+            config:
+              # The SAML attribute (after mapping via the attribute maps) to use
+              # to derive the Matrix ID from. 'uid' by default.
+              #
+              # Note: This used to be configured by the
+              # saml2_config.mxid_source_attribute option. If that is still
+              # defined, its value will be used instead.
+              #
+              #mxid_source_attribute: displayName
+
+              # The mapping system to use for mapping the saml attribute onto a
+              # matrix ID.
+              #
+              # Options include:
+              #  * 'hexencode' (which maps unpermitted characters to '=xx')
+              #  * 'dotreplace' (which replaces unpermitted characters with
+              #     '.').
+              # The default is 'hexencode'.
+              #
+              # Note: This used to be configured by the
+              # saml2_config.mxid_mapping option. If that is still defined, its
+              # value will be used instead.
+              #
+              #mxid_mapping: dotreplace
+
+          # In previous versions of synapse, the mapping from SAML attribute to
+          # MXID was always calculated dynamically rather than stored in a
+          # table. For backwards- compatibility, we will look for user_ids
+          # matching such a pattern before creating a new account.
           #
           # This setting controls the SAML attribute which will be used for this
-          # backwards-compatibility lookup. Typically it should be 'uid', but if the
-          # attribute maps are changed, it may be necessary to change it.
+          # backwards-compatibility lookup. Typically it should be 'uid', but if
+          # the attribute maps are changed, it may be necessary to change it.
           #
           # The default is 'uid'.
           #
@@ -241,23 +327,3 @@ class SAML2Config(Config):
         """ % {
             "config_dir_path": config_dir_path
         }
-
-
-DOT_REPLACE_PATTERN = re.compile(
-    ("[^%s]" % (re.escape("".join(mxid_localpart_allowed_characters)),))
-)
-
-
-def dot_replace_for_mxid(username: str) -> str:
-    username = username.lower()
-    username = DOT_REPLACE_PATTERN.sub(".", username)
-
-    # regular mxids aren't allowed to start with an underscore either
-    username = re.sub("^_", "", username)
-    return username
-
-
-MXID_MAPPER_MAP = {
-    "hexencode": map_username_to_mxid_localpart,
-    "dotreplace": dot_replace_for_mxid,
-}
diff --git a/synapse/config/server.py b/synapse/config/server.py
index a4bef00936..38f6ff9edc 100644
--- a/synapse/config/server.py
+++ b/synapse/config/server.py
@@ -102,6 +102,12 @@ class ServerConfig(Config):
             "require_auth_for_profile_requests", False
         )
 
+        # Whether to require sharing a room with a user to retrieve their
+        # profile data
+        self.limit_profile_requests_to_users_who_share_rooms = config.get(
+            "limit_profile_requests_to_users_who_share_rooms", False,
+        )
+
         if "restrict_public_rooms_to_local_users" in config and (
             "allow_public_rooms_without_auth" in config
             or "allow_public_rooms_over_federation" in config
@@ -200,7 +206,7 @@ class ServerConfig(Config):
         self.admin_contact = config.get("admin_contact", None)
 
         # FIXME: federation_domain_whitelist needs sytests
-        self.federation_domain_whitelist = None
+        self.federation_domain_whitelist = None  # type: Optional[dict]
         federation_domain_whitelist = config.get("federation_domain_whitelist", None)
 
         if federation_domain_whitelist is not None:
@@ -621,6 +627,13 @@ class ServerConfig(Config):
         #
         #require_auth_for_profile_requests: true
 
+        # Uncomment to require a user to share a room with another user in order
+        # to retrieve their profile information. Only checked on Client-Server
+        # requests. Profile requests from other servers should be checked by the
+        # requesting server. Defaults to 'false'.
+        #
+        #limit_profile_requests_to_users_who_share_rooms: true
+
         # If set to 'true', removes the need for authentication to access the server's
         # public rooms directory through the client API, meaning that anyone can
         # query the room directory. Defaults to 'false'.
diff --git a/synapse/event_auth.py b/synapse/event_auth.py
index 350ed9351f..1033e5e121 100644
--- a/synapse/event_auth.py
+++ b/synapse/event_auth.py
@@ -43,6 +43,8 @@ def check(room_version, event, auth_events, do_sig_check=True, do_size_check=Tru
     Returns:
          if the auth checks pass.
     """
+    assert isinstance(auth_events, dict)
+
     if do_size_check:
         _check_size_limits(event)
 
@@ -87,12 +89,6 @@ def check(room_version, event, auth_events, do_sig_check=True, do_size_check=Tru
             if not event.signatures.get(event_id_domain):
                 raise AuthError(403, "Event not signed by sending server")
 
-    if auth_events is None:
-        # Oh, we don't know what the state of the room was, so we
-        # are trusting that this is allowed (at least for now)
-        logger.warning("Trusting event: %s", event.event_id)
-        return
-
     if event.type == EventTypes.Create:
         sender_domain = get_domain_from_id(event.sender)
         room_id_domain = get_domain_from_id(event.room_id)
diff --git a/synapse/events/snapshot.py b/synapse/events/snapshot.py
index 64e898f40c..a44baea365 100644
--- a/synapse/events/snapshot.py
+++ b/synapse/events/snapshot.py
@@ -149,7 +149,7 @@ class EventContext:
         # the prev_state_ids, so if we're a state event we include the event
         # id that we replaced in the state.
         if event.is_state():
-            prev_state_ids = yield self.get_prev_state_ids(store)
+            prev_state_ids = yield self.get_prev_state_ids()
             prev_state_id = prev_state_ids.get((event.type, event.state_key))
         else:
             prev_state_id = None
@@ -167,12 +167,13 @@ class EventContext:
         }
 
     @staticmethod
-    def deserialize(store, input):
+    def deserialize(storage, input):
         """Converts a dict that was produced by `serialize` back into a
         EventContext.
 
         Args:
-            store (DataStore): Used to convert AS ID to AS object
+            storage (Storage): Used to convert AS ID to AS object and fetch
+                state.
             input (dict): A dict produced by `serialize`
 
         Returns:
@@ -181,6 +182,7 @@ class EventContext:
         context = _AsyncEventContextImpl(
             # We use the state_group and prev_state_id stuff to pull the
             # current_state_ids out of the DB and construct prev_state_ids.
+            storage=storage,
             prev_state_id=input["prev_state_id"],
             event_type=input["event_type"],
             event_state_key=input["event_state_key"],
@@ -193,7 +195,7 @@ class EventContext:
 
         app_service_id = input["app_service_id"]
         if app_service_id:
-            context.app_service = store.get_app_service_by_id(app_service_id)
+            context.app_service = storage.main.get_app_service_by_id(app_service_id)
 
         return context
 
@@ -216,7 +218,7 @@ class EventContext:
         return self._state_group
 
     @defer.inlineCallbacks
-    def get_current_state_ids(self, store):
+    def get_current_state_ids(self):
         """
         Gets the room state map, including this event - ie, the state in ``state_group``
 
@@ -234,11 +236,11 @@ class EventContext:
         if self.rejected:
             raise RuntimeError("Attempt to access state_ids of rejected event")
 
-        yield self._ensure_fetched(store)
+        yield self._ensure_fetched()
         return self._current_state_ids
 
     @defer.inlineCallbacks
-    def get_prev_state_ids(self, store):
+    def get_prev_state_ids(self):
         """
         Gets the room state map, excluding this event.
 
@@ -250,7 +252,7 @@ class EventContext:
                 Maps a (type, state_key) to the event ID of the state event matching
                 this tuple.
         """
-        yield self._ensure_fetched(store)
+        yield self._ensure_fetched()
         return self._prev_state_ids
 
     def get_cached_current_state_ids(self):
@@ -270,7 +272,7 @@ class EventContext:
 
         return self._current_state_ids
 
-    def _ensure_fetched(self, store):
+    def _ensure_fetched(self):
         return defer.succeed(None)
 
 
@@ -282,6 +284,8 @@ class _AsyncEventContextImpl(EventContext):
 
     Attributes:
 
+        _storage (Storage)
+
         _fetching_state_deferred (Deferred|None): Resolves when *_state_ids have
             been calculated. None if we haven't started calculating yet
 
@@ -295,28 +299,30 @@ class _AsyncEventContextImpl(EventContext):
             that was replaced.
     """
 
+    # This needs to have a default as we're inheriting
+    _storage = attr.ib(default=None)
     _prev_state_id = attr.ib(default=None)
     _event_type = attr.ib(default=None)
     _event_state_key = attr.ib(default=None)
     _fetching_state_deferred = attr.ib(default=None)
 
-    def _ensure_fetched(self, store):
+    def _ensure_fetched(self):
         if not self._fetching_state_deferred:
-            self._fetching_state_deferred = run_in_background(
-                self._fill_out_state, store
-            )
+            self._fetching_state_deferred = run_in_background(self._fill_out_state)
 
         return make_deferred_yieldable(self._fetching_state_deferred)
 
     @defer.inlineCallbacks
-    def _fill_out_state(self, store):
+    def _fill_out_state(self):
         """Called to populate the _current_state_ids and _prev_state_ids
         attributes by loading from the database.
         """
         if self.state_group is None:
             return
 
-        self._current_state_ids = yield store.get_state_ids_for_group(self.state_group)
+        self._current_state_ids = yield self._storage.state.get_state_ids_for_group(
+            self.state_group
+        )
         if self._prev_state_id and self._event_state_key is not None:
             self._prev_state_ids = dict(self._current_state_ids)
 
diff --git a/synapse/events/third_party_rules.py b/synapse/events/third_party_rules.py
index 714a9b1579..86f7e5f8aa 100644
--- a/synapse/events/third_party_rules.py
+++ b/synapse/events/third_party_rules.py
@@ -53,7 +53,7 @@ class ThirdPartyEventRules(object):
         if self.third_party_rules is None:
             return True
 
-        prev_state_ids = yield context.get_prev_state_ids(self.store)
+        prev_state_ids = yield context.get_prev_state_ids()
 
         # Retrieve the state events from the database.
         state_events = {}
diff --git a/synapse/federation/federation_client.py b/synapse/federation/federation_client.py
index d396e6564f..af652a7659 100644
--- a/synapse/federation/federation_client.py
+++ b/synapse/federation/federation_client.py
@@ -526,13 +526,7 @@ class FederationClient(FederationBase):
 
         @defer.inlineCallbacks
         def send_request(destination):
-            time_now = self._clock.time_msec()
-            _, content = yield self.transport_layer.send_join(
-                destination=destination,
-                room_id=pdu.room_id,
-                event_id=pdu.event_id,
-                content=pdu.get_pdu_json(time_now),
-            )
+            content = yield self._do_send_join(destination, pdu)
 
             logger.debug("Got content: %s", content)
 
@@ -600,6 +594,44 @@ class FederationClient(FederationBase):
         return self._try_destination_list("send_join", destinations, send_request)
 
     @defer.inlineCallbacks
+    def _do_send_join(self, destination, pdu):
+        time_now = self._clock.time_msec()
+
+        try:
+            content = yield self.transport_layer.send_join_v2(
+                destination=destination,
+                room_id=pdu.room_id,
+                event_id=pdu.event_id,
+                content=pdu.get_pdu_json(time_now),
+            )
+
+            return content
+        except HttpResponseException as e:
+            if e.code in [400, 404]:
+                err = e.to_synapse_error()
+
+                # If we receive an error response that isn't a generic error, or an
+                # unrecognised endpoint error, we  assume that the remote understands
+                # the v2 invite API and this is a legitimate error.
+                if err.errcode not in [Codes.UNKNOWN, Codes.UNRECOGNIZED]:
+                    raise err
+            else:
+                raise e.to_synapse_error()
+
+        logger.debug("Couldn't send_join with the v2 API, falling back to the v1 API")
+
+        resp = yield self.transport_layer.send_join_v1(
+            destination=destination,
+            room_id=pdu.room_id,
+            event_id=pdu.event_id,
+            content=pdu.get_pdu_json(time_now),
+        )
+
+        # We expect the v1 API to respond with [200, content], so we only return the
+        # content.
+        return resp[1]
+
+    @defer.inlineCallbacks
     def send_invite(self, destination, room_id, event_id, pdu):
         room_version = yield self.store.get_room_version(room_id)
 
@@ -708,18 +740,50 @@ class FederationClient(FederationBase):
 
         @defer.inlineCallbacks
         def send_request(destination):
-            time_now = self._clock.time_msec()
-            _, content = yield self.transport_layer.send_leave(
+            content = yield self._do_send_leave(destination, pdu)
+
+            logger.debug("Got content: %s", content)
+            return None
+
+        return self._try_destination_list("send_leave", destinations, send_request)
+
+    @defer.inlineCallbacks
+    def _do_send_leave(self, destination, pdu):
+        time_now = self._clock.time_msec()
+
+        try:
+            content = yield self.transport_layer.send_leave_v2(
                 destination=destination,
                 room_id=pdu.room_id,
                 event_id=pdu.event_id,
                 content=pdu.get_pdu_json(time_now),
             )
 
-            logger.debug("Got content: %s", content)
-            return None
+            return content
+        except HttpResponseException as e:
+            if e.code in [400, 404]:
+                err = e.to_synapse_error()
 
-        return self._try_destination_list("send_leave", destinations, send_request)
+                # If we receive an error response that isn't a generic error, or an
+                # unrecognised endpoint error, we  assume that the remote understands
+                # the v2 invite API and this is a legitimate error.
+                if err.errcode not in [Codes.UNKNOWN, Codes.UNRECOGNIZED]:
+                    raise err
+            else:
+                raise e.to_synapse_error()
+
+        logger.debug("Couldn't send_leave with the v2 API, falling back to the v1 API")
+
+        resp = yield self.transport_layer.send_leave_v1(
+            destination=destination,
+            room_id=pdu.room_id,
+            event_id=pdu.event_id,
+            content=pdu.get_pdu_json(time_now),
+        )
+
+        # We expect the v1 API to respond with [200, content], so we only return the
+        # content.
+        return resp[1]
 
     def get_public_rooms(
         self,
diff --git a/synapse/federation/federation_server.py b/synapse/federation/federation_server.py
index 84d4eca041..d7ce333822 100644
--- a/synapse/federation/federation_server.py
+++ b/synapse/federation/federation_server.py
@@ -384,15 +384,10 @@ class FederationServer(FederationBase):
 
         res_pdus = await self.handler.on_send_join_request(origin, pdu)
         time_now = self._clock.time_msec()
-        return (
-            200,
-            {
-                "state": [p.get_pdu_json(time_now) for p in res_pdus["state"]],
-                "auth_chain": [
-                    p.get_pdu_json(time_now) for p in res_pdus["auth_chain"]
-                ],
-            },
-        )
+        return {
+            "state": [p.get_pdu_json(time_now) for p in res_pdus["state"]],
+            "auth_chain": [p.get_pdu_json(time_now) for p in res_pdus["auth_chain"]],
+        }
 
     async def on_make_leave_request(self, origin, room_id, user_id):
         origin_host, _ = parse_server_name(origin)
@@ -419,7 +414,7 @@ class FederationServer(FederationBase):
         pdu = await self._check_sigs_and_hash(room_version, pdu)
 
         await self.handler.on_send_leave_request(origin, pdu)
-        return 200, {}
+        return {}
 
     async def on_event_auth(self, origin, room_id, event_id):
         with (await self._server_linearizer.queue((origin, room_id))):
diff --git a/synapse/federation/transport/client.py b/synapse/federation/transport/client.py
index 46dba84cac..198257414b 100644
--- a/synapse/federation/transport/client.py
+++ b/synapse/federation/transport/client.py
@@ -243,7 +243,7 @@ class TransportLayerClient(object):
 
     @defer.inlineCallbacks
     @log_function
-    def send_join(self, destination, room_id, event_id, content):
+    def send_join_v1(self, destination, room_id, event_id, content):
         path = _create_v1_path("/send_join/%s/%s", room_id, event_id)
 
         response = yield self.client.put_json(
@@ -254,7 +254,18 @@ class TransportLayerClient(object):
 
     @defer.inlineCallbacks
     @log_function
-    def send_leave(self, destination, room_id, event_id, content):
+    def send_join_v2(self, destination, room_id, event_id, content):
+        path = _create_v2_path("/send_join/%s/%s", room_id, event_id)
+
+        response = yield self.client.put_json(
+            destination=destination, path=path, data=content
+        )
+
+        return response
+
+    @defer.inlineCallbacks
+    @log_function
+    def send_leave_v1(self, destination, room_id, event_id, content):
         path = _create_v1_path("/send_leave/%s/%s", room_id, event_id)
 
         response = yield self.client.put_json(
@@ -272,6 +283,24 @@ class TransportLayerClient(object):
 
     @defer.inlineCallbacks
     @log_function
+    def send_leave_v2(self, destination, room_id, event_id, content):
+        path = _create_v2_path("/send_leave/%s/%s", room_id, event_id)
+
+        response = yield self.client.put_json(
+            destination=destination,
+            path=path,
+            data=content,
+            # we want to do our best to send this through. The problem is
+            # that if it fails, we won't retry it later, so if the remote
+            # server was just having a momentary blip, the room will be out of
+            # sync.
+            ignore_backoff=True,
+        )
+
+        return response
+
+    @defer.inlineCallbacks
+    @log_function
     def send_invite_v1(self, destination, room_id, event_id, content):
         path = _create_v1_path("/invite/%s/%s", room_id, event_id)
 
diff --git a/synapse/federation/transport/server.py b/synapse/federation/transport/server.py
index fefc789c85..b4cbf23394 100644
--- a/synapse/federation/transport/server.py
+++ b/synapse/federation/transport/server.py
@@ -506,11 +506,21 @@ class FederationMakeLeaveServlet(BaseFederationServlet):
         return 200, content
 
 
-class FederationSendLeaveServlet(BaseFederationServlet):
+class FederationV1SendLeaveServlet(BaseFederationServlet):
     PATH = "/send_leave/(?P<room_id>[^/]*)/(?P<event_id>[^/]*)"
 
     async def on_PUT(self, origin, content, query, room_id, event_id):
         content = await self.handler.on_send_leave_request(origin, content, room_id)
+        return 200, (200, content)
+
+
+class FederationV2SendLeaveServlet(BaseFederationServlet):
+    PATH = "/send_leave/(?P<room_id>[^/]*)/(?P<event_id>[^/]*)"
+
+    PREFIX = FEDERATION_V2_PREFIX
+
+    async def on_PUT(self, origin, content, query, room_id, event_id):
+        content = await self.handler.on_send_leave_request(origin, content, room_id)
         return 200, content
 
 
@@ -521,9 +531,21 @@ class FederationEventAuthServlet(BaseFederationServlet):
         return await self.handler.on_event_auth(origin, context, event_id)
 
 
-class FederationSendJoinServlet(BaseFederationServlet):
+class FederationV1SendJoinServlet(BaseFederationServlet):
+    PATH = "/send_join/(?P<context>[^/]*)/(?P<event_id>[^/]*)"
+
+    async def on_PUT(self, origin, content, query, context, event_id):
+        # TODO(paul): assert that context/event_id parsed from path actually
+        #   match those given in content
+        content = await self.handler.on_send_join_request(origin, content, context)
+        return 200, (200, content)
+
+
+class FederationV2SendJoinServlet(BaseFederationServlet):
     PATH = "/send_join/(?P<context>[^/]*)/(?P<event_id>[^/]*)"
 
+    PREFIX = FEDERATION_V2_PREFIX
+
     async def on_PUT(self, origin, content, query, context, event_id):
         # TODO(paul): assert that context/event_id parsed from path actually
         #   match those given in content
@@ -1367,8 +1389,10 @@ FEDERATION_SERVLET_CLASSES = (
     FederationMakeJoinServlet,
     FederationMakeLeaveServlet,
     FederationEventServlet,
-    FederationSendJoinServlet,
-    FederationSendLeaveServlet,
+    FederationV1SendJoinServlet,
+    FederationV2SendJoinServlet,
+    FederationV1SendLeaveServlet,
+    FederationV2SendLeaveServlet,
     FederationV1InviteServlet,
     FederationV2InviteServlet,
     FederationQueryAuthServlet,
diff --git a/synapse/groups/groups_server.py b/synapse/groups/groups_server.py
index 29e8ffc295..0ec9be3cb5 100644
--- a/synapse/groups/groups_server.py
+++ b/synapse/groups/groups_server.py
@@ -773,6 +773,11 @@ class GroupsServerHandler(object):
         if not self.hs.is_mine_id(user_id):
             yield self.store.maybe_delete_remote_profile_cache(user_id)
 
+        # Delete group if the last user has left
+        users = yield self.store.get_users_in_group(group_id, include_private=True)
+        if not users:
+            yield self.store.delete_group(group_id)
+
         return {}
 
     @defer.inlineCallbacks
diff --git a/synapse/handlers/_base.py b/synapse/handlers/_base.py
index d15c6282fb..51413d910e 100644
--- a/synapse/handlers/_base.py
+++ b/synapse/handlers/_base.py
@@ -134,7 +134,7 @@ class BaseHandler(object):
             guest_access = event.content.get("guest_access", "forbidden")
             if guest_access != "can_join":
                 if context:
-                    current_state_ids = yield context.get_current_state_ids(self.store)
+                    current_state_ids = yield context.get_current_state_ids()
                     current_state = yield self.store.get_events(
                         list(current_state_ids.values())
                     )
diff --git a/synapse/handlers/account_data.py b/synapse/handlers/account_data.py
index 2d7e6df6e4..a8d3fbc6de 100644
--- a/synapse/handlers/account_data.py
+++ b/synapse/handlers/account_data.py
@@ -13,8 +13,6 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
-from twisted.internet import defer
-
 
 class AccountDataEventSource(object):
     def __init__(self, hs):
@@ -23,15 +21,14 @@ class AccountDataEventSource(object):
     def get_current_key(self, direction="f"):
         return self.store.get_max_account_data_stream_id()
 
-    @defer.inlineCallbacks
-    def get_new_events(self, user, from_key, **kwargs):
+    async def get_new_events(self, user, from_key, **kwargs):
         user_id = user.to_string()
         last_stream_id = from_key
 
-        current_stream_id = yield self.store.get_max_account_data_stream_id()
+        current_stream_id = self.store.get_max_account_data_stream_id()
 
         results = []
-        tags = yield self.store.get_updated_tags(user_id, last_stream_id)
+        tags = await self.store.get_updated_tags(user_id, last_stream_id)
 
         for room_id, room_tags in tags.items():
             results.append(
@@ -41,7 +38,7 @@ class AccountDataEventSource(object):
         (
             account_data,
             room_account_data,
-        ) = yield self.store.get_updated_account_data_for_user(user_id, last_stream_id)
+        ) = await self.store.get_updated_account_data_for_user(user_id, last_stream_id)
 
         for account_data_type, content in account_data.items():
             results.append({"type": account_data_type, "content": content})
@@ -53,7 +50,3 @@ class AccountDataEventSource(object):
                 )
 
         return results, current_stream_id
-
-    @defer.inlineCallbacks
-    def get_pagination_rows(self, user, config, key):
-        return [], config.to_id
diff --git a/synapse/handlers/account_validity.py b/synapse/handlers/account_validity.py
index d04e0fe576..829f52eca1 100644
--- a/synapse/handlers/account_validity.py
+++ b/synapse/handlers/account_validity.py
@@ -18,8 +18,7 @@ import email.utils
 import logging
 from email.mime.multipart import MIMEMultipart
 from email.mime.text import MIMEText
-
-from twisted.internet import defer
+from typing import List
 
 from synapse.api.errors import StoreError
 from synapse.logging.context import make_deferred_yieldable
@@ -78,42 +77,39 @@ class AccountValidityHandler(object):
                 # run as a background process to make sure that the database transactions
                 # have a logcontext to report to
                 return run_as_background_process(
-                    "send_renewals", self.send_renewal_emails
+                    "send_renewals", self._send_renewal_emails
                 )
 
             self.clock.looping_call(send_emails, 30 * 60 * 1000)
 
-    @defer.inlineCallbacks
-    def send_renewal_emails(self):
+    async def _send_renewal_emails(self):
         """Gets the list of users whose account is expiring in the amount of time
         configured in the ``renew_at`` parameter from the ``account_validity``
         configuration, and sends renewal emails to all of these users as long as they
         have an email 3PID attached to their account.
         """
-        expiring_users = yield self.store.get_users_expiring_soon()
+        expiring_users = await self.store.get_users_expiring_soon()
 
         if expiring_users:
             for user in expiring_users:
-                yield self._send_renewal_email(
+                await self._send_renewal_email(
                     user_id=user["user_id"], expiration_ts=user["expiration_ts_ms"]
                 )
 
-    @defer.inlineCallbacks
-    def send_renewal_email_to_user(self, user_id):
-        expiration_ts = yield self.store.get_expiration_ts_for_user(user_id)
-        yield self._send_renewal_email(user_id, expiration_ts)
+    async def send_renewal_email_to_user(self, user_id: str):
+        expiration_ts = await self.store.get_expiration_ts_for_user(user_id)
+        await self._send_renewal_email(user_id, expiration_ts)
 
-    @defer.inlineCallbacks
-    def _send_renewal_email(self, user_id, expiration_ts):
+    async def _send_renewal_email(self, user_id: str, expiration_ts: int):
         """Sends out a renewal email to every email address attached to the given user
         with a unique link allowing them to renew their account.
 
         Args:
-            user_id (str): ID of the user to send email(s) to.
-            expiration_ts (int): Timestamp in milliseconds for the expiration date of
+            user_id: ID of the user to send email(s) to.
+            expiration_ts: Timestamp in milliseconds for the expiration date of
                 this user's account (used in the email templates).
         """
-        addresses = yield self._get_email_addresses_for_user(user_id)
+        addresses = await self._get_email_addresses_for_user(user_id)
 
         # Stop right here if the user doesn't have at least one email address.
         # In this case, they will have to ask their server admin to renew their
@@ -125,7 +121,7 @@ class AccountValidityHandler(object):
             return
 
         try:
-            user_display_name = yield self.store.get_profile_displayname(
+            user_display_name = await self.store.get_profile_displayname(
                 UserID.from_string(user_id).localpart
             )
             if user_display_name is None:
@@ -133,7 +129,7 @@ class AccountValidityHandler(object):
         except StoreError:
             user_display_name = user_id
 
-        renewal_token = yield self._get_renewal_token(user_id)
+        renewal_token = await self._get_renewal_token(user_id)
         url = "%s_matrix/client/unstable/account_validity/renew?token=%s" % (
             self.hs.config.public_baseurl,
             renewal_token,
@@ -165,7 +161,7 @@ class AccountValidityHandler(object):
 
             logger.info("Sending renewal email to %s", address)
 
-            yield make_deferred_yieldable(
+            await make_deferred_yieldable(
                 self.sendmail(
                     self.hs.config.email_smtp_host,
                     self._raw_from,
@@ -180,19 +176,18 @@ class AccountValidityHandler(object):
                 )
             )
 
-        yield self.store.set_renewal_mail_status(user_id=user_id, email_sent=True)
+        await self.store.set_renewal_mail_status(user_id=user_id, email_sent=True)
 
-    @defer.inlineCallbacks
-    def _get_email_addresses_for_user(self, user_id):
+    async def _get_email_addresses_for_user(self, user_id: str) -> List[str]:
         """Retrieve the list of email addresses attached to a user's account.
 
         Args:
-            user_id (str): ID of the user to lookup email addresses for.
+            user_id: ID of the user to lookup email addresses for.
 
         Returns:
-            defer.Deferred[list[str]]: Email addresses for this account.
+            Email addresses for this account.
         """
-        threepids = yield self.store.user_get_threepids(user_id)
+        threepids = await self.store.user_get_threepids(user_id)
 
         addresses = []
         for threepid in threepids:
@@ -201,16 +196,15 @@ class AccountValidityHandler(object):
 
         return addresses
 
-    @defer.inlineCallbacks
-    def _get_renewal_token(self, user_id):
+    async def _get_renewal_token(self, user_id: str) -> str:
         """Generates a 32-byte long random string that will be inserted into the
         user's renewal email's unique link, then saves it into the database.
 
         Args:
-            user_id (str): ID of the user to generate a string for.
+            user_id: ID of the user to generate a string for.
 
         Returns:
-            defer.Deferred[str]: The generated string.
+            The generated string.
 
         Raises:
             StoreError(500): Couldn't generate a unique string after 5 attempts.
@@ -219,52 +213,52 @@ class AccountValidityHandler(object):
         while attempts < 5:
             try:
                 renewal_token = stringutils.random_string(32)
-                yield self.store.set_renewal_token_for_user(user_id, renewal_token)
+                await self.store.set_renewal_token_for_user(user_id, renewal_token)
                 return renewal_token
             except StoreError:
                 attempts += 1
         raise StoreError(500, "Couldn't generate a unique string as refresh string.")
 
-    @defer.inlineCallbacks
-    def renew_account(self, renewal_token):
+    async def renew_account(self, renewal_token: str) -> bool:
         """Renews the account attached to a given renewal token by pushing back the
         expiration date by the current validity period in the server's configuration.
 
         Args:
-            renewal_token (str): Token sent with the renewal request.
+            renewal_token: Token sent with the renewal request.
         Returns:
-            bool: Whether the provided token is valid.
+            Whether the provided token is valid.
         """
         try:
-            user_id = yield self.store.get_user_from_renewal_token(renewal_token)
+            user_id = await self.store.get_user_from_renewal_token(renewal_token)
         except StoreError:
-            defer.returnValue(False)
+            return False
 
         logger.debug("Renewing an account for user %s", user_id)
-        yield self.renew_account_for_user(user_id)
+        await self.renew_account_for_user(user_id)
 
-        defer.returnValue(True)
+        return True
 
-    @defer.inlineCallbacks
-    def renew_account_for_user(self, user_id, expiration_ts=None, email_sent=False):
+    async def renew_account_for_user(
+        self, user_id: str, expiration_ts: int = None, email_sent: bool = False
+    ) -> int:
         """Renews the account attached to a given user by pushing back the
         expiration date by the current validity period in the server's
         configuration.
 
         Args:
-            renewal_token (str): Token sent with the renewal request.
-            expiration_ts (int): New expiration date. Defaults to now + validity period.
-            email_sent (bool): Whether an email has been sent for this validity period.
+            renewal_token: Token sent with the renewal request.
+            expiration_ts: New expiration date. Defaults to now + validity period.
+            email_sen: Whether an email has been sent for this validity period.
                 Defaults to False.
 
         Returns:
-            defer.Deferred[int]: New expiration date for this account, as a timestamp
-                in milliseconds since epoch.
+            New expiration date for this account, as a timestamp in
+            milliseconds since epoch.
         """
         if expiration_ts is None:
             expiration_ts = self.clock.time_msec() + self._account_validity.period
 
-        yield self.store.set_account_validity_for_user(
+        await self.store.set_account_validity_for_user(
             user_id=user_id, expiration_ts=expiration_ts, email_sent=email_sent
         )
 
diff --git a/synapse/handlers/admin.py b/synapse/handlers/admin.py
index 14449b9a1e..1a4ba12385 100644
--- a/synapse/handlers/admin.py
+++ b/synapse/handlers/admin.py
@@ -15,8 +15,6 @@
 
 import logging
 
-from twisted.internet import defer
-
 from synapse.api.constants import Membership
 from synapse.types import RoomStreamToken
 from synapse.visibility import filter_events_for_client
@@ -33,11 +31,10 @@ class AdminHandler(BaseHandler):
         self.storage = hs.get_storage()
         self.state_store = self.storage.state
 
-    @defer.inlineCallbacks
-    def get_whois(self, user):
+    async def get_whois(self, user):
         connections = []
 
-        sessions = yield self.store.get_user_ip_and_agents(user)
+        sessions = await self.store.get_user_ip_and_agents(user)
         for session in sessions:
             connections.append(
                 {
@@ -54,20 +51,18 @@ class AdminHandler(BaseHandler):
 
         return ret
 
-    @defer.inlineCallbacks
-    def get_users(self):
+    async def get_users(self):
         """Function to retrieve a list of users in users table.
 
         Args:
         Returns:
             defer.Deferred: resolves to list[dict[str, Any]]
         """
-        ret = yield self.store.get_users()
+        ret = await self.store.get_users()
 
         return ret
 
-    @defer.inlineCallbacks
-    def get_users_paginate(self, start, limit, name, guests, deactivated):
+    async def get_users_paginate(self, start, limit, name, guests, deactivated):
         """Function to retrieve a paginated list of users from
         users list. This will return a json list of users.
 
@@ -80,14 +75,13 @@ class AdminHandler(BaseHandler):
         Returns:
             defer.Deferred: resolves to json list[dict[str, Any]]
         """
-        ret = yield self.store.get_users_paginate(
+        ret = await self.store.get_users_paginate(
             start, limit, name, guests, deactivated
         )
 
         return ret
 
-    @defer.inlineCallbacks
-    def search_users(self, term):
+    async def search_users(self, term):
         """Function to search users list for one or more users with
         the matched term.
 
@@ -96,7 +90,7 @@ class AdminHandler(BaseHandler):
         Returns:
             defer.Deferred: resolves to list[dict[str, Any]]
         """
-        ret = yield self.store.search_users(term)
+        ret = await self.store.search_users(term)
 
         return ret
 
@@ -119,8 +113,7 @@ class AdminHandler(BaseHandler):
         """
         return self.store.set_server_admin(user, admin)
 
-    @defer.inlineCallbacks
-    def export_user_data(self, user_id, writer):
+    async def export_user_data(self, user_id, writer):
         """Write all data we have on the user to the given writer.
 
         Args:
@@ -132,7 +125,7 @@ class AdminHandler(BaseHandler):
             The returned value is that returned by `writer.finished()`.
         """
         # Get all rooms the user is in or has been in
-        rooms = yield self.store.get_rooms_for_user_where_membership_is(
+        rooms = await self.store.get_rooms_for_user_where_membership_is(
             user_id,
             membership_list=(
                 Membership.JOIN,
@@ -145,7 +138,7 @@ class AdminHandler(BaseHandler):
         # We only try and fetch events for rooms the user has been in. If
         # they've been e.g. invited to a room without joining then we handle
         # those seperately.
-        rooms_user_has_been_in = yield self.store.get_rooms_user_has_been_in(user_id)
+        rooms_user_has_been_in = await self.store.get_rooms_user_has_been_in(user_id)
 
         for index, room in enumerate(rooms):
             room_id = room.room_id
@@ -154,7 +147,7 @@ class AdminHandler(BaseHandler):
                 "[%s] Handling room %s, %d/%d", user_id, room_id, index + 1, len(rooms)
             )
 
-            forgotten = yield self.store.did_forget(user_id, room_id)
+            forgotten = await self.store.did_forget(user_id, room_id)
             if forgotten:
                 logger.info("[%s] User forgot room %d, ignoring", user_id, room_id)
                 continue
@@ -166,7 +159,7 @@ class AdminHandler(BaseHandler):
 
                 if room.membership == Membership.INVITE:
                     event_id = room.event_id
-                    invite = yield self.store.get_event(event_id, allow_none=True)
+                    invite = await self.store.get_event(event_id, allow_none=True)
                     if invite:
                         invited_state = invite.unsigned["invite_room_state"]
                         writer.write_invite(room_id, invite, invited_state)
@@ -177,7 +170,7 @@ class AdminHandler(BaseHandler):
             # were joined. We estimate that point by looking at the
             # stream_ordering of the last membership if it wasn't a join.
             if room.membership == Membership.JOIN:
-                stream_ordering = yield self.store.get_room_max_stream_ordering()
+                stream_ordering = self.store.get_room_max_stream_ordering()
             else:
                 stream_ordering = room.stream_ordering
 
@@ -203,7 +196,7 @@ class AdminHandler(BaseHandler):
             # events that we have and then filtering, this isn't the most
             # efficient method perhaps but it does guarantee we get everything.
             while True:
-                events, _ = yield self.store.paginate_room_events(
+                events, _ = await self.store.paginate_room_events(
                     room_id, from_key, to_key, limit=100, direction="f"
                 )
                 if not events:
@@ -211,7 +204,7 @@ class AdminHandler(BaseHandler):
 
                 from_key = events[-1].internal_metadata.after
 
-                events = yield filter_events_for_client(self.storage, user_id, events)
+                events = await filter_events_for_client(self.storage, user_id, events)
 
                 writer.write_events(room_id, events)
 
@@ -247,7 +240,7 @@ class AdminHandler(BaseHandler):
             for event_id in extremities:
                 if not event_to_unseen_prevs[event_id]:
                     continue
-                state = yield self.state_store.get_state_for_event(event_id)
+                state = await self.state_store.get_state_for_event(event_id)
                 writer.write_state(room_id, event_id, state)
 
         return writer.finished()
diff --git a/synapse/handlers/deactivate_account.py b/synapse/handlers/deactivate_account.py
index 6dedaaff8d..4426967f88 100644
--- a/synapse/handlers/deactivate_account.py
+++ b/synapse/handlers/deactivate_account.py
@@ -15,8 +15,6 @@
 # limitations under the License.
 import logging
 
-from twisted.internet import defer
-
 from synapse.api.errors import SynapseError
 from synapse.metrics.background_process_metrics import run_as_background_process
 from synapse.types import UserID, create_requester
@@ -46,8 +44,7 @@ class DeactivateAccountHandler(BaseHandler):
 
         self._account_validity_enabled = hs.config.account_validity.enabled
 
-    @defer.inlineCallbacks
-    def deactivate_account(self, user_id, erase_data, id_server=None):
+    async def deactivate_account(self, user_id, erase_data, id_server=None):
         """Deactivate a user's account
 
         Args:
@@ -74,11 +71,11 @@ class DeactivateAccountHandler(BaseHandler):
         identity_server_supports_unbinding = True
 
         # Retrieve the 3PIDs this user has bound to an identity server
-        threepids = yield self.store.user_get_bound_threepids(user_id)
+        threepids = await self.store.user_get_bound_threepids(user_id)
 
         for threepid in threepids:
             try:
-                result = yield self._identity_handler.try_unbind_threepid(
+                result = await self._identity_handler.try_unbind_threepid(
                     user_id,
                     {
                         "medium": threepid["medium"],
@@ -91,33 +88,33 @@ class DeactivateAccountHandler(BaseHandler):
                 # Do we want this to be a fatal error or should we carry on?
                 logger.exception("Failed to remove threepid from ID server")
                 raise SynapseError(400, "Failed to remove threepid from ID server")
-            yield self.store.user_delete_threepid(
+            await self.store.user_delete_threepid(
                 user_id, threepid["medium"], threepid["address"]
             )
 
         # Remove all 3PIDs this user has bound to the homeserver
-        yield self.store.user_delete_threepids(user_id)
+        await self.store.user_delete_threepids(user_id)
 
         # delete any devices belonging to the user, which will also
         # delete corresponding access tokens.
-        yield self._device_handler.delete_all_devices_for_user(user_id)
+        await self._device_handler.delete_all_devices_for_user(user_id)
         # then delete any remaining access tokens which weren't associated with
         # a device.
-        yield self._auth_handler.delete_access_tokens_for_user(user_id)
+        await self._auth_handler.delete_access_tokens_for_user(user_id)
 
-        yield self.store.user_set_password_hash(user_id, None)
+        await self.store.user_set_password_hash(user_id, None)
 
         # Add the user to a table of users pending deactivation (ie.
         # removal from all the rooms they're a member of)
-        yield self.store.add_user_pending_deactivation(user_id)
+        await self.store.add_user_pending_deactivation(user_id)
 
         # delete from user directory
-        yield self.user_directory_handler.handle_user_deactivated(user_id)
+        await self.user_directory_handler.handle_user_deactivated(user_id)
 
         # Mark the user as erased, if they asked for that
         if erase_data:
             logger.info("Marking %s as erased", user_id)
-            yield self.store.mark_user_erased(user_id)
+            await self.store.mark_user_erased(user_id)
 
         # Now start the process that goes through that list and
         # parts users from rooms (if it isn't already running)
@@ -125,30 +122,29 @@ class DeactivateAccountHandler(BaseHandler):
 
         # Reject all pending invites for the user, so that the user doesn't show up in the
         # "invited" section of rooms' members list.
-        yield self._reject_pending_invites_for_user(user_id)
+        await self._reject_pending_invites_for_user(user_id)
 
         # Remove all information on the user from the account_validity table.
         if self._account_validity_enabled:
-            yield self.store.delete_account_validity_for_user(user_id)
+            await self.store.delete_account_validity_for_user(user_id)
 
         # Mark the user as deactivated.
-        yield self.store.set_user_deactivated_status(user_id, True)
+        await self.store.set_user_deactivated_status(user_id, True)
 
         return identity_server_supports_unbinding
 
-    @defer.inlineCallbacks
-    def _reject_pending_invites_for_user(self, user_id):
+    async def _reject_pending_invites_for_user(self, user_id):
         """Reject pending invites addressed to a given user ID.
 
         Args:
             user_id (str): The user ID to reject pending invites for.
         """
         user = UserID.from_string(user_id)
-        pending_invites = yield self.store.get_invited_rooms_for_user(user_id)
+        pending_invites = await self.store.get_invited_rooms_for_user(user_id)
 
         for room in pending_invites:
             try:
-                yield self._room_member_handler.update_membership(
+                await self._room_member_handler.update_membership(
                     create_requester(user),
                     user,
                     room.room_id,
@@ -180,8 +176,7 @@ class DeactivateAccountHandler(BaseHandler):
         if not self._user_parter_running:
             run_as_background_process("user_parter_loop", self._user_parter_loop)
 
-    @defer.inlineCallbacks
-    def _user_parter_loop(self):
+    async def _user_parter_loop(self):
         """Loop that parts deactivated users from rooms
 
         Returns:
@@ -191,19 +186,18 @@ class DeactivateAccountHandler(BaseHandler):
         logger.info("Starting user parter")
         try:
             while True:
-                user_id = yield self.store.get_user_pending_deactivation()
+                user_id = await self.store.get_user_pending_deactivation()
                 if user_id is None:
                     break
                 logger.info("User parter parting %r", user_id)
-                yield self._part_user(user_id)
-                yield self.store.del_user_pending_deactivation(user_id)
+                await self._part_user(user_id)
+                await self.store.del_user_pending_deactivation(user_id)
                 logger.info("User parter finished parting %r", user_id)
             logger.info("User parter finished: stopping")
         finally:
             self._user_parter_running = False
 
-    @defer.inlineCallbacks
-    def _part_user(self, user_id):
+    async def _part_user(self, user_id):
         """Causes the given user_id to leave all the rooms they're joined to
 
         Returns:
@@ -211,11 +205,11 @@ class DeactivateAccountHandler(BaseHandler):
         """
         user = UserID.from_string(user_id)
 
-        rooms_for_user = yield self.store.get_rooms_for_user(user_id)
+        rooms_for_user = await self.store.get_rooms_for_user(user_id)
         for room_id in rooms_for_user:
             logger.info("User parter parting %r from %r", user_id, room_id)
             try:
-                yield self._room_member_handler.update_membership(
+                await self._room_member_handler.update_membership(
                     create_requester(user),
                     user,
                     room_id,
diff --git a/synapse/handlers/e2e_keys.py b/synapse/handlers/e2e_keys.py
index 57a10daefd..2d889364d4 100644
--- a/synapse/handlers/e2e_keys.py
+++ b/synapse/handlers/e2e_keys.py
@@ -264,6 +264,7 @@ class E2eKeysHandler(object):
 
         return ret
 
+    @defer.inlineCallbacks
     def get_cross_signing_keys_from_cache(self, query, from_user_id):
         """Get cross-signing keys for users from the database
 
@@ -283,14 +284,32 @@ class E2eKeysHandler(object):
         self_signing_keys = {}
         user_signing_keys = {}
 
-        # Currently a stub, implementation coming in https://github.com/matrix-org/synapse/pull/6486
-        return defer.succeed(
-            {
-                "master_keys": master_keys,
-                "self_signing_keys": self_signing_keys,
-                "user_signing_keys": user_signing_keys,
-            }
-        )
+        user_ids = list(query)
+
+        keys = yield self.store.get_e2e_cross_signing_keys_bulk(user_ids, from_user_id)
+
+        for user_id, user_info in keys.items():
+            if user_info is None:
+                continue
+            if "master" in user_info:
+                master_keys[user_id] = user_info["master"]
+            if "self_signing" in user_info:
+                self_signing_keys[user_id] = user_info["self_signing"]
+
+        if (
+            from_user_id in keys
+            and keys[from_user_id] is not None
+            and "user_signing" in keys[from_user_id]
+        ):
+            # users can see other users' master and self-signing keys, but can
+            # only see their own user-signing keys
+            user_signing_keys[from_user_id] = keys[from_user_id]["user_signing"]
+
+        return {
+            "master_keys": master_keys,
+            "self_signing_keys": self_signing_keys,
+            "user_signing_keys": user_signing_keys,
+        }
 
     @trace
     @defer.inlineCallbacks
diff --git a/synapse/handlers/federation.py b/synapse/handlers/federation.py
index 6fb453ce60..72a0febc2b 100644
--- a/synapse/handlers/federation.py
+++ b/synapse/handlers/federation.py
@@ -19,7 +19,7 @@
 
 import itertools
 import logging
-from typing import Dict, Iterable, Optional, Sequence, Tuple
+from typing import Dict, Iterable, List, Optional, Sequence, Tuple
 
 import six
 from six import iteritems, itervalues
@@ -63,6 +63,7 @@ from synapse.replication.http.federation import (
 )
 from synapse.replication.http.membership import ReplicationUserJoinedLeftRoomRestServlet
 from synapse.state import StateResolutionStore, resolve_events_with_store
+from synapse.storage.data_stores.main.events_worker import EventRedactBehaviour
 from synapse.types import UserID, get_domain_from_id
 from synapse.util.async_helpers import Linearizer, concurrently_execute
 from synapse.util.distributor import user_joined_room
@@ -163,8 +164,7 @@ class FederationHandler(BaseHandler):
 
         self._ephemeral_messages_enabled = hs.config.enable_ephemeral_messages
 
-    @defer.inlineCallbacks
-    def on_receive_pdu(self, origin, pdu, sent_to_us_directly=False):
+    async def on_receive_pdu(self, origin, pdu, sent_to_us_directly=False) -> None:
         """ Process a PDU received via a federation /send/ transaction, or
         via backfill of missing prev_events
 
@@ -174,17 +174,15 @@ class FederationHandler(BaseHandler):
             pdu (FrozenEvent): received PDU
             sent_to_us_directly (bool): True if this event was pushed to us; False if
                 we pulled it as the result of a missing prev_event.
-
-        Returns (Deferred): completes with None
         """
 
         room_id = pdu.room_id
         event_id = pdu.event_id
 
-        logger.info("[%s %s] handling received PDU: %s", room_id, event_id, pdu)
+        logger.info("handling received PDU: %s", pdu)
 
         # We reprocess pdus when we have seen them only as outliers
-        existing = yield self.store.get_event(
+        existing = await self.store.get_event(
             event_id, allow_none=True, allow_rejected=True
         )
 
@@ -228,7 +226,7 @@ class FederationHandler(BaseHandler):
         #
         # Note that if we were never in the room then we would have already
         # dropped the event, since we wouldn't know the room version.
-        is_in_room = yield self.auth.check_host_in_room(room_id, self.server_name)
+        is_in_room = await self.auth.check_host_in_room(room_id, self.server_name)
         if not is_in_room:
             logger.info(
                 "[%s %s] Ignoring PDU from %s as we're not in the room",
@@ -243,12 +241,12 @@ class FederationHandler(BaseHandler):
         # Get missing pdus if necessary.
         if not pdu.internal_metadata.is_outlier():
             # We only backfill backwards to the min depth.
-            min_depth = yield self.get_min_depth_for_context(pdu.room_id)
+            min_depth = await self.get_min_depth_for_context(pdu.room_id)
 
             logger.debug("[%s %s] min_depth: %d", room_id, event_id, min_depth)
 
             prevs = set(pdu.prev_event_ids())
-            seen = yield self.store.have_seen_events(prevs)
+            seen = await self.store.have_seen_events(prevs)
 
             if min_depth and pdu.depth < min_depth:
                 # This is so that we don't notify the user about this
@@ -268,7 +266,7 @@ class FederationHandler(BaseHandler):
                         len(missing_prevs),
                         shortstr(missing_prevs),
                     )
-                    with (yield self._room_pdu_linearizer.queue(pdu.room_id)):
+                    with (await self._room_pdu_linearizer.queue(pdu.room_id)):
                         logger.info(
                             "[%s %s] Acquired room lock to fetch %d missing prev_events",
                             room_id,
@@ -276,13 +274,19 @@ class FederationHandler(BaseHandler):
                             len(missing_prevs),
                         )
 
-                        yield self._get_missing_events_for_pdu(
-                            origin, pdu, prevs, min_depth
-                        )
+                        try:
+                            await self._get_missing_events_for_pdu(
+                                origin, pdu, prevs, min_depth
+                            )
+                        except Exception as e:
+                            raise Exception(
+                                "Error fetching missing prev_events for %s: %s"
+                                % (event_id, e)
+                            )
 
                         # Update the set of things we've seen after trying to
                         # fetch the missing stuff
-                        seen = yield self.store.have_seen_events(prevs)
+                        seen = await self.store.have_seen_events(prevs)
 
                         if not prevs - seen:
                             logger.info(
@@ -290,14 +294,6 @@ class FederationHandler(BaseHandler):
                                 room_id,
                                 event_id,
                             )
-                elif missing_prevs:
-                    logger.info(
-                        "[%s %s] Not recursively fetching %d missing prev_events: %s",
-                        room_id,
-                        event_id,
-                        len(missing_prevs),
-                        shortstr(missing_prevs),
-                    )
 
             if prevs - seen:
                 # We've still not been able to get all of the prev_events for this event.
@@ -342,12 +338,18 @@ class FederationHandler(BaseHandler):
                         affected=pdu.event_id,
                     )
 
+                logger.info(
+                    "Event %s is missing prev_events: calculating state for a "
+                    "backwards extremity",
+                    event_id,
+                )
+
                 # Calculate the state after each of the previous events, and
                 # resolve them to find the correct state at the current event.
                 event_map = {event_id: pdu}
                 try:
                     # Get the state of the events we know about
-                    ours = yield self.state_store.get_state_groups_ids(room_id, seen)
+                    ours = await self.state_store.get_state_groups_ids(room_id, seen)
 
                     # state_maps is a list of mappings from (type, state_key) to event_id
                     state_maps = list(
@@ -361,17 +363,14 @@ class FederationHandler(BaseHandler):
                     # know about
                     for p in prevs - seen:
                         logger.info(
-                            "[%s %s] Requesting state at missing prev_event %s",
-                            room_id,
-                            event_id,
-                            p,
+                            "Requesting state at missing prev_event %s", event_id,
                         )
 
                         with nested_logging_context(p):
                             # note that if any of the missing prevs share missing state or
                             # auth events, the requests to fetch those events are deduped
                             # by the get_pdu_cache in federation_client.
-                            (remote_state, _,) = yield self._get_state_for_room(
+                            (remote_state, _,) = await self._get_state_for_room(
                                 origin, room_id, p, include_event_in_state=True
                             )
 
@@ -383,8 +382,8 @@ class FederationHandler(BaseHandler):
                             for x in remote_state:
                                 event_map[x.event_id] = x
 
-                    room_version = yield self.store.get_room_version(room_id)
-                    state_map = yield resolve_events_with_store(
+                    room_version = await self.store.get_room_version(room_id)
+                    state_map = await resolve_events_with_store(
                         room_id,
                         room_version,
                         state_maps,
@@ -397,10 +396,10 @@ class FederationHandler(BaseHandler):
 
                     # First though we need to fetch all the events that are in
                     # state_map, so we can build up the state below.
-                    evs = yield self.store.get_events(
+                    evs = await self.store.get_events(
                         list(state_map.values()),
                         get_prev_content=False,
-                        check_redacted=False,
+                        redact_behaviour=EventRedactBehaviour.AS_IS,
                     )
                     event_map.update(evs)
 
@@ -420,10 +419,9 @@ class FederationHandler(BaseHandler):
                         affected=event_id,
                     )
 
-        yield self._process_received_pdu(origin, pdu, state=state)
+        await self._process_received_pdu(origin, pdu, state=state)
 
-    @defer.inlineCallbacks
-    def _get_missing_events_for_pdu(self, origin, pdu, prevs, min_depth):
+    async def _get_missing_events_for_pdu(self, origin, pdu, prevs, min_depth):
         """
         Args:
             origin (str): Origin of the pdu. Will be called to get the missing events
@@ -435,12 +433,12 @@ class FederationHandler(BaseHandler):
         room_id = pdu.room_id
         event_id = pdu.event_id
 
-        seen = yield self.store.have_seen_events(prevs)
+        seen = await self.store.have_seen_events(prevs)
 
         if not prevs - seen:
             return
 
-        latest = yield self.store.get_latest_event_ids_in_room(room_id)
+        latest = await self.store.get_latest_event_ids_in_room(room_id)
 
         # We add the prev events that we have seen to the latest
         # list to ensure the remote server doesn't give them to us
@@ -504,7 +502,7 @@ class FederationHandler(BaseHandler):
         # All that said: Let's try increasing the timout to 60s and see what happens.
 
         try:
-            missing_events = yield self.federation_client.get_missing_events(
+            missing_events = await self.federation_client.get_missing_events(
                 origin,
                 room_id,
                 earliest_events_ids=list(latest),
@@ -543,7 +541,7 @@ class FederationHandler(BaseHandler):
             )
             with nested_logging_context(ev.event_id):
                 try:
-                    yield self.on_receive_pdu(origin, ev, sent_to_us_directly=False)
+                    await self.on_receive_pdu(origin, ev, sent_to_us_directly=False)
                 except FederationError as e:
                     if e.code == 403:
                         logger.warning(
@@ -555,29 +553,30 @@ class FederationHandler(BaseHandler):
                     else:
                         raise
 
-    @defer.inlineCallbacks
-    @log_function
-    def _get_state_for_room(
-        self, destination, room_id, event_id, include_event_in_state
-    ):
+    async def _get_state_for_room(
+        self,
+        destination: str,
+        room_id: str,
+        event_id: str,
+        include_event_in_state: bool = False,
+    ) -> Tuple[List[EventBase], List[EventBase]]:
         """Requests all of the room state at a given event from a remote homeserver.
 
         Args:
-            destination (str): The remote homeserver to query for the state.
-            room_id (str): The id of the room we're interested in.
-            event_id (str): The id of the event we want the state at.
+            destination: The remote homeserver to query for the state.
+            room_id: The id of the room we're interested in.
+            event_id: The id of the event we want the state at.
             include_event_in_state: if true, the event itself will be included in the
                 returned state event list.
 
         Returns:
-            Deferred[Tuple[List[EventBase], List[EventBase]]]:
-                A list of events in the state, and a list of events in the auth chain
-                for the given event.
+            A list of events in the state, possibly including the event itself, and
+            a list of events in the auth chain for the given event.
         """
         (
             state_event_ids,
             auth_event_ids,
-        ) = yield self.federation_client.get_room_state_ids(
+        ) = await self.federation_client.get_room_state_ids(
             destination, room_id, event_id=event_id
         )
 
@@ -586,15 +585,15 @@ class FederationHandler(BaseHandler):
         if include_event_in_state:
             desired_events.add(event_id)
 
-        event_map = yield self._get_events_from_store_or_dest(
+        event_map = await self._get_events_from_store_or_dest(
             destination, room_id, desired_events
         )
 
         failed_to_fetch = desired_events - event_map.keys()
         if failed_to_fetch:
             logger.warning(
-                "Failed to fetch missing state/auth events for %s: %s",
-                room_id,
+                "Failed to fetch missing state/auth events for %s %s",
+                event_id,
                 failed_to_fetch,
             )
 
@@ -614,15 +613,11 @@ class FederationHandler(BaseHandler):
 
         return remote_state, auth_chain
 
-    @defer.inlineCallbacks
-    def _get_events_from_store_or_dest(self, destination, room_id, event_ids):
+    async def _get_events_from_store_or_dest(
+        self, destination: str, room_id: str, event_ids: Iterable[str]
+    ) -> Dict[str, EventBase]:
         """Fetch events from a remote destination, checking if we already have them.
 
-        Args:
-            destination (str)
-            room_id (str)
-            event_ids (Iterable[str])
-
         Persists any events we don't already have as outliers.
 
         If we fail to fetch any of the events, a warning will be logged, and the event
@@ -630,10 +625,9 @@ class FederationHandler(BaseHandler):
         be in the given room.
 
         Returns:
-            Deferred[dict[str, EventBase]]: A deferred resolving to a map
-            from event_id to event
+            map from event_id to event
         """
-        fetched_events = yield self.store.get_events(event_ids, allow_rejected=True)
+        fetched_events = await self.store.get_events(event_ids, allow_rejected=True)
 
         missing_events = set(event_ids) - fetched_events.keys()
 
@@ -644,14 +638,14 @@ class FederationHandler(BaseHandler):
                 room_id,
             )
 
-            yield self._get_events_and_persist(
+            await self._get_events_and_persist(
                 destination=destination, room_id=room_id, events=missing_events
             )
 
             # we need to make sure we re-load from the database to get the rejected
             # state correct.
             fetched_events.update(
-                (yield self.store.get_events(missing_events, allow_rejected=True))
+                (await self.store.get_events(missing_events, allow_rejected=True))
             )
 
         # check for events which were in the wrong room.
@@ -677,12 +671,14 @@ class FederationHandler(BaseHandler):
                 bad_room_id,
                 room_id,
             )
+
             del fetched_events[bad_event_id]
 
         return fetched_events
 
-    @defer.inlineCallbacks
-    def _process_received_pdu(self, origin, event, state):
+    async def _process_received_pdu(
+        self, origin: str, event: EventBase, state: Optional[Iterable[EventBase]],
+    ):
         """ Called when we have a new pdu. We need to do auth checks and put it
         through the StateHandler.
 
@@ -701,15 +697,15 @@ class FederationHandler(BaseHandler):
         logger.debug("[%s %s] Processing event: %s", room_id, event_id, event)
 
         try:
-            context = yield self._handle_new_event(origin, event, state=state)
+            context = await self._handle_new_event(origin, event, state=state)
         except AuthError as e:
             raise FederationError("ERROR", e.code, e.msg, affected=event.event_id)
 
-        room = yield self.store.get_room(room_id)
+        room = await self.store.get_room(room_id)
 
         if not room:
             try:
-                yield self.store.store_room(
+                await self.store.store_room(
                     room_id=room_id, room_creator_user_id="", is_public=False
                 )
             except StoreError:
@@ -722,11 +718,11 @@ class FederationHandler(BaseHandler):
                 # changing their profile info.
                 newly_joined = True
 
-                prev_state_ids = yield context.get_prev_state_ids(self.store)
+                prev_state_ids = await context.get_prev_state_ids()
 
                 prev_state_id = prev_state_ids.get((event.type, event.state_key))
                 if prev_state_id:
-                    prev_state = yield self.store.get_event(
+                    prev_state = await self.store.get_event(
                         prev_state_id, allow_none=True
                     )
                     if prev_state and prev_state.membership == Membership.JOIN:
@@ -734,11 +730,10 @@ class FederationHandler(BaseHandler):
 
                 if newly_joined:
                     user = UserID.from_string(event.state_key)
-                    yield self.user_joined_room(user, room_id)
+                    await self.user_joined_room(user, room_id)
 
     @log_function
-    @defer.inlineCallbacks
-    def backfill(self, dest, room_id, limit, extremities):
+    async def backfill(self, dest, room_id, limit, extremities):
         """ Trigger a backfill request to `dest` for the given `room_id`
 
         This will attempt to get more events from the remote. If the other side
@@ -755,7 +750,7 @@ class FederationHandler(BaseHandler):
         if dest == self.server_name:
             raise SynapseError(400, "Can't backfill from self.")
 
-        events = yield self.federation_client.backfill(
+        events = await self.federation_client.backfill(
             dest, room_id, limit=limit, extremities=extremities
         )
 
@@ -770,7 +765,7 @@ class FederationHandler(BaseHandler):
         #     self._sanity_check_event(ev)
 
         # Don't bother processing events we already have.
-        seen_events = yield self.store.have_events_in_timeline(
+        seen_events = await self.store.have_events_in_timeline(
             set(e.event_id for e in events)
         )
 
@@ -796,7 +791,7 @@ class FederationHandler(BaseHandler):
         state_events = {}
         events_to_state = {}
         for e_id in edges:
-            state, auth = yield self._get_state_for_room(
+            state, auth = await self._get_state_for_room(
                 destination=dest,
                 room_id=room_id,
                 event_id=e_id,
@@ -843,7 +838,7 @@ class FederationHandler(BaseHandler):
                 )
             )
 
-        yield self._handle_new_events(dest, ev_infos, backfilled=True)
+        await self._handle_new_events(dest, ev_infos, backfilled=True)
 
         # Step 2: Persist the rest of the events in the chunk one by one
         events.sort(key=lambda e: e.depth)
@@ -859,16 +854,15 @@ class FederationHandler(BaseHandler):
             # We store these one at a time since each event depends on the
             # previous to work out the state.
             # TODO: We can probably do something more clever here.
-            yield self._handle_new_event(dest, event, backfilled=True)
+            await self._handle_new_event(dest, event, backfilled=True)
 
         return events
 
-    @defer.inlineCallbacks
-    def maybe_backfill(self, room_id, current_depth):
+    async def maybe_backfill(self, room_id, current_depth):
         """Checks the database to see if we should backfill before paginating,
         and if so do.
         """
-        extremities = yield self.store.get_oldest_events_with_depth_in_room(room_id)
+        extremities = await self.store.get_oldest_events_with_depth_in_room(room_id)
 
         if not extremities:
             logger.debug("Not backfilling as no extremeties found.")
@@ -900,15 +894,17 @@ class FederationHandler(BaseHandler):
         #   state *before* the event, ignoring the special casing certain event
         #   types have.
 
-        forward_events = yield self.store.get_successor_events(list(extremities))
+        forward_events = await self.store.get_successor_events(list(extremities))
 
-        extremities_events = yield self.store.get_events(
-            forward_events, check_redacted=False, get_prev_content=False
+        extremities_events = await self.store.get_events(
+            forward_events,
+            redact_behaviour=EventRedactBehaviour.AS_IS,
+            get_prev_content=False,
         )
 
         # We set `check_history_visibility_only` as we might otherwise get false
         # positives from users having been erased.
-        filtered_extremities = yield filter_events_for_server(
+        filtered_extremities = await filter_events_for_server(
             self.storage,
             self.server_name,
             list(extremities_events.values()),
@@ -938,7 +934,7 @@ class FederationHandler(BaseHandler):
         # First we try hosts that are already in the room
         # TODO: HEURISTIC ALERT.
 
-        curr_state = yield self.state_handler.get_current_state(room_id)
+        curr_state = await self.state_handler.get_current_state(room_id)
 
         def get_domains_from_state(state):
             """Get joined domains from state
@@ -977,12 +973,11 @@ class FederationHandler(BaseHandler):
             domain for domain, depth in curr_domains if domain != self.server_name
         ]
 
-        @defer.inlineCallbacks
-        def try_backfill(domains):
+        async def try_backfill(domains):
             # TODO: Should we try multiple of these at a time?
             for dom in domains:
                 try:
-                    yield self.backfill(
+                    await self.backfill(
                         dom, room_id, limit=100, extremities=extremities
                     )
                     # If this succeeded then we probably already have the
@@ -1013,7 +1008,7 @@ class FederationHandler(BaseHandler):
 
             return False
 
-        success = yield try_backfill(likely_domains)
+        success = await try_backfill(likely_domains)
         if success:
             return True
 
@@ -1027,7 +1022,7 @@ class FederationHandler(BaseHandler):
 
         logger.debug("calling resolve_state_groups in _maybe_backfill")
         resolve = preserve_fn(self.state_handler.resolve_state_groups_for_events)
-        states = yield make_deferred_yieldable(
+        states = await make_deferred_yieldable(
             defer.gatherResults(
                 [resolve(room_id, [e]) for e in event_ids], consumeErrors=True
             )
@@ -1037,7 +1032,7 @@ class FederationHandler(BaseHandler):
         # event_ids.
         states = dict(zip(event_ids, [s.state for s in states]))
 
-        state_map = yield self.store.get_events(
+        state_map = await self.store.get_events(
             [e_id for ids in itervalues(states) for e_id in itervalues(ids)],
             get_prev_content=False,
         )
@@ -1053,7 +1048,7 @@ class FederationHandler(BaseHandler):
         for e_id, _ in sorted_extremeties_tuple:
             likely_domains = get_domains_from_state(states[e_id])
 
-            success = yield try_backfill(
+            success = await try_backfill(
                 [dom for dom, _ in likely_domains if dom not in tried_domains]
             )
             if success:
@@ -1063,8 +1058,7 @@ class FederationHandler(BaseHandler):
 
         return False
 
-    @defer.inlineCallbacks
-    def _get_events_and_persist(
+    async def _get_events_and_persist(
         self, destination: str, room_id: str, events: Iterable[str]
     ):
         """Fetch the given events from a server, and persist them as outliers.
@@ -1072,7 +1066,7 @@ class FederationHandler(BaseHandler):
         Logs a warning if we can't find the given event.
         """
 
-        room_version = yield self.store.get_room_version(room_id)
+        room_version = await self.store.get_room_version(room_id)
 
         event_infos = []
 
@@ -1108,9 +1102,9 @@ class FederationHandler(BaseHandler):
                         e,
                     )
 
-        yield concurrently_execute(get_event, events, 5)
+        await concurrently_execute(get_event, events, 5)
 
-        yield self._handle_new_events(
+        await self._handle_new_events(
             destination, event_infos,
         )
 
@@ -1253,7 +1247,7 @@ class FederationHandler(BaseHandler):
             # Check whether this room is the result of an upgrade of a room we already know
             # about. If so, migrate over user information
             predecessor = yield self.store.get_room_predecessor(room_id)
-            if not predecessor:
+            if not predecessor or not isinstance(predecessor.get("room_id"), str):
                 return
             old_room_id = predecessor["room_id"]
             logger.debug(
@@ -1281,8 +1275,7 @@ class FederationHandler(BaseHandler):
 
         return True
 
-    @defer.inlineCallbacks
-    def _handle_queued_pdus(self, room_queue):
+    async def _handle_queued_pdus(self, room_queue):
         """Process PDUs which got queued up while we were busy send_joining.
 
         Args:
@@ -1298,7 +1291,7 @@ class FederationHandler(BaseHandler):
                     p.room_id,
                 )
                 with nested_logging_context(p.event_id):
-                    yield self.on_receive_pdu(origin, p, sent_to_us_directly=True)
+                    await self.on_receive_pdu(origin, p, sent_to_us_directly=True)
             except Exception as e:
                 logger.warning(
                     "Error handling queued PDU %s from %s: %s", p.event_id, origin, e
@@ -1428,7 +1421,7 @@ class FederationHandler(BaseHandler):
                 user = UserID.from_string(event.state_key)
                 yield self.user_joined_room(user, event.room_id)
 
-        prev_state_ids = yield context.get_prev_state_ids(self.store)
+        prev_state_ids = yield context.get_prev_state_ids()
 
         state_ids = list(prev_state_ids.values())
         auth_chain = yield self.store.get_auth_chain(state_ids)
@@ -1496,7 +1489,7 @@ class FederationHandler(BaseHandler):
     @defer.inlineCallbacks
     def do_remotely_reject_invite(self, target_hosts, room_id, user_id, content):
         origin, event, event_format_version = yield self._make_and_verify_event(
-            target_hosts, room_id, user_id, "leave", content=content,
+            target_hosts, room_id, user_id, "leave", content=content
         )
         # Mark as outlier as we don't have any state for this event; we're not
         # even in the room.
@@ -1937,7 +1930,7 @@ class FederationHandler(BaseHandler):
         context = yield self.state_handler.compute_event_context(event, old_state=state)
 
         if not auth_events:
-            prev_state_ids = yield context.get_prev_state_ids(self.store)
+            prev_state_ids = yield context.get_prev_state_ids()
             auth_events_ids = yield self.auth.compute_auth_events(
                 event, prev_state_ids, for_verification=True
             )
@@ -2346,12 +2339,12 @@ class FederationHandler(BaseHandler):
             k: a.event_id for k, a in iteritems(auth_events) if k != event_key
         }
 
-        current_state_ids = yield context.get_current_state_ids(self.store)
+        current_state_ids = yield context.get_current_state_ids()
         current_state_ids = dict(current_state_ids)
 
         current_state_ids.update(state_updates)
 
-        prev_state_ids = yield context.get_prev_state_ids(self.store)
+        prev_state_ids = yield context.get_prev_state_ids()
         prev_state_ids = dict(prev_state_ids)
 
         prev_state_ids.update({k: a.event_id for k, a in iteritems(auth_events)})
@@ -2635,7 +2628,7 @@ class FederationHandler(BaseHandler):
             event.content["third_party_invite"]["signed"]["token"],
         )
         original_invite = None
-        prev_state_ids = yield context.get_prev_state_ids(self.store)
+        prev_state_ids = yield context.get_prev_state_ids()
         original_invite_id = prev_state_ids.get(key)
         if original_invite_id:
             original_invite = yield self.store.get_event(
@@ -2683,7 +2676,7 @@ class FederationHandler(BaseHandler):
         signed = event.content["third_party_invite"]["signed"]
         token = signed["token"]
 
-        prev_state_ids = yield context.get_prev_state_ids(self.store)
+        prev_state_ids = yield context.get_prev_state_ids()
         invite_event_id = prev_state_ids.get((EventTypes.ThirdPartyInvite, token))
 
         invite_event = None
@@ -2857,7 +2850,7 @@ class FederationHandler(BaseHandler):
                 room_id=room_id, user_id=user.to_string(), change="joined"
             )
         else:
-            return user_joined_room(self.distributor, user, room_id)
+            return defer.succeed(user_joined_room(self.distributor, user, room_id))
 
     @defer.inlineCallbacks
     def get_room_complexity(self, remote_room_hosts, room_id):
diff --git a/synapse/handlers/initial_sync.py b/synapse/handlers/initial_sync.py
index 81dce96f4b..44ec3e66ae 100644
--- a/synapse/handlers/initial_sync.py
+++ b/synapse/handlers/initial_sync.py
@@ -26,7 +26,7 @@ from synapse.streams.config import PaginationConfig
 from synapse.types import StreamToken, UserID
 from synapse.util import unwrapFirstError
 from synapse.util.async_helpers import concurrently_execute
-from synapse.util.caches.snapshot_cache import SnapshotCache
+from synapse.util.caches.response_cache import ResponseCache
 from synapse.visibility import filter_events_for_client
 
 from ._base import BaseHandler
@@ -41,7 +41,7 @@ class InitialSyncHandler(BaseHandler):
         self.state = hs.get_state_handler()
         self.clock = hs.get_clock()
         self.validator = EventValidator()
-        self.snapshot_cache = SnapshotCache()
+        self.snapshot_cache = ResponseCache(hs, "initial_sync_cache")
         self._event_serializer = hs.get_event_client_serializer()
         self.storage = hs.get_storage()
         self.state_store = self.storage.state
@@ -79,21 +79,17 @@ class InitialSyncHandler(BaseHandler):
             as_client_event,
             include_archived,
         )
-        now_ms = self.clock.time_msec()
-        result = self.snapshot_cache.get(now_ms, key)
-        if result is not None:
-            return result
 
-        return self.snapshot_cache.set(
-            now_ms,
+        return self.snapshot_cache.wrap(
             key,
-            self._snapshot_all_rooms(
-                user_id, pagin_config, as_client_event, include_archived
-            ),
+            self._snapshot_all_rooms,
+            user_id,
+            pagin_config,
+            as_client_event,
+            include_archived,
         )
 
-    @defer.inlineCallbacks
-    def _snapshot_all_rooms(
+    async def _snapshot_all_rooms(
         self,
         user_id=None,
         pagin_config=None,
@@ -105,7 +101,7 @@ class InitialSyncHandler(BaseHandler):
         if include_archived:
             memberships.append(Membership.LEAVE)
 
-        room_list = yield self.store.get_rooms_for_user_where_membership_is(
+        room_list = await self.store.get_rooms_for_user_where_membership_is(
             user_id=user_id, membership_list=memberships
         )
 
@@ -113,33 +109,32 @@ class InitialSyncHandler(BaseHandler):
 
         rooms_ret = []
 
-        now_token = yield self.hs.get_event_sources().get_current_token()
+        now_token = await self.hs.get_event_sources().get_current_token()
 
         presence_stream = self.hs.get_event_sources().sources["presence"]
         pagination_config = PaginationConfig(from_token=now_token)
-        presence, _ = yield presence_stream.get_pagination_rows(
+        presence, _ = await presence_stream.get_pagination_rows(
             user, pagination_config.get_source_config("presence"), None
         )
 
         receipt_stream = self.hs.get_event_sources().sources["receipt"]
-        receipt, _ = yield receipt_stream.get_pagination_rows(
+        receipt, _ = await receipt_stream.get_pagination_rows(
             user, pagination_config.get_source_config("receipt"), None
         )
 
-        tags_by_room = yield self.store.get_tags_for_user(user_id)
+        tags_by_room = await self.store.get_tags_for_user(user_id)
 
-        account_data, account_data_by_room = yield self.store.get_account_data_for_user(
+        account_data, account_data_by_room = await self.store.get_account_data_for_user(
             user_id
         )
 
-        public_room_ids = yield self.store.get_public_room_ids()
+        public_room_ids = await self.store.get_public_room_ids()
 
         limit = pagin_config.limit
         if limit is None:
             limit = 10
 
-        @defer.inlineCallbacks
-        def handle_room(event):
+        async def handle_room(event):
             d = {
                 "room_id": event.room_id,
                 "membership": event.membership,
@@ -152,8 +147,8 @@ class InitialSyncHandler(BaseHandler):
                 time_now = self.clock.time_msec()
                 d["inviter"] = event.sender
 
-                invite_event = yield self.store.get_event(event.event_id)
-                d["invite"] = yield self._event_serializer.serialize_event(
+                invite_event = await self.store.get_event(event.event_id)
+                d["invite"] = await self._event_serializer.serialize_event(
                     invite_event, time_now, as_client_event
                 )
 
@@ -177,7 +172,7 @@ class InitialSyncHandler(BaseHandler):
                         lambda states: states[event.event_id]
                     )
 
-                (messages, token), current_state = yield make_deferred_yieldable(
+                (messages, token), current_state = await make_deferred_yieldable(
                     defer.gatherResults(
                         [
                             run_in_background(
@@ -191,7 +186,7 @@ class InitialSyncHandler(BaseHandler):
                     )
                 ).addErrback(unwrapFirstError)
 
-                messages = yield filter_events_for_client(
+                messages = await filter_events_for_client(
                     self.storage, user_id, messages
                 )
 
@@ -201,7 +196,7 @@ class InitialSyncHandler(BaseHandler):
 
                 d["messages"] = {
                     "chunk": (
-                        yield self._event_serializer.serialize_events(
+                        await self._event_serializer.serialize_events(
                             messages, time_now=time_now, as_client_event=as_client_event
                         )
                     ),
@@ -209,7 +204,7 @@ class InitialSyncHandler(BaseHandler):
                     "end": end_token.to_string(),
                 }
 
-                d["state"] = yield self._event_serializer.serialize_events(
+                d["state"] = await self._event_serializer.serialize_events(
                     current_state.values(),
                     time_now=time_now,
                     as_client_event=as_client_event,
@@ -232,7 +227,7 @@ class InitialSyncHandler(BaseHandler):
             except Exception:
                 logger.exception("Failed to get snapshot")
 
-        yield concurrently_execute(handle_room, room_list, 10)
+        await concurrently_execute(handle_room, room_list, 10)
 
         account_data_events = []
         for account_data_type, content in account_data.items():
@@ -256,8 +251,7 @@ class InitialSyncHandler(BaseHandler):
 
         return ret
 
-    @defer.inlineCallbacks
-    def room_initial_sync(self, requester, room_id, pagin_config=None):
+    async def room_initial_sync(self, requester, room_id, pagin_config=None):
         """Capture the a snapshot of a room. If user is currently a member of
         the room this will be what is currently in the room. If the user left
         the room this will be what was in the room when they left.
@@ -274,32 +268,32 @@ class InitialSyncHandler(BaseHandler):
             A JSON serialisable dict with the snapshot of the room.
         """
 
-        blocked = yield self.store.is_room_blocked(room_id)
+        blocked = await self.store.is_room_blocked(room_id)
         if blocked:
             raise SynapseError(403, "This room has been blocked on this server")
 
         user_id = requester.user.to_string()
 
-        membership, member_event_id = yield self._check_in_room_or_world_readable(
+        membership, member_event_id = await self._check_in_room_or_world_readable(
             room_id, user_id
         )
         is_peeking = member_event_id is None
 
         if membership == Membership.JOIN:
-            result = yield self._room_initial_sync_joined(
+            result = await self._room_initial_sync_joined(
                 user_id, room_id, pagin_config, membership, is_peeking
             )
         elif membership == Membership.LEAVE:
-            result = yield self._room_initial_sync_parted(
+            result = await self._room_initial_sync_parted(
                 user_id, room_id, pagin_config, membership, member_event_id, is_peeking
             )
 
         account_data_events = []
-        tags = yield self.store.get_tags_for_room(user_id, room_id)
+        tags = await self.store.get_tags_for_room(user_id, room_id)
         if tags:
             account_data_events.append({"type": "m.tag", "content": {"tags": tags}})
 
-        account_data = yield self.store.get_account_data_for_room(user_id, room_id)
+        account_data = await self.store.get_account_data_for_room(user_id, room_id)
         for account_data_type, content in account_data.items():
             account_data_events.append({"type": account_data_type, "content": content})
 
@@ -307,11 +301,10 @@ class InitialSyncHandler(BaseHandler):
 
         return result
 
-    @defer.inlineCallbacks
-    def _room_initial_sync_parted(
+    async def _room_initial_sync_parted(
         self, user_id, room_id, pagin_config, membership, member_event_id, is_peeking
     ):
-        room_state = yield self.state_store.get_state_for_events([member_event_id])
+        room_state = await self.state_store.get_state_for_events([member_event_id])
 
         room_state = room_state[member_event_id]
 
@@ -319,13 +312,13 @@ class InitialSyncHandler(BaseHandler):
         if limit is None:
             limit = 10
 
-        stream_token = yield self.store.get_stream_token_for_event(member_event_id)
+        stream_token = await self.store.get_stream_token_for_event(member_event_id)
 
-        messages, token = yield self.store.get_recent_events_for_room(
+        messages, token = await self.store.get_recent_events_for_room(
             room_id, limit=limit, end_token=stream_token
         )
 
-        messages = yield filter_events_for_client(
+        messages = await filter_events_for_client(
             self.storage, user_id, messages, is_peeking=is_peeking
         )
 
@@ -339,13 +332,13 @@ class InitialSyncHandler(BaseHandler):
             "room_id": room_id,
             "messages": {
                 "chunk": (
-                    yield self._event_serializer.serialize_events(messages, time_now)
+                    await self._event_serializer.serialize_events(messages, time_now)
                 ),
                 "start": start_token.to_string(),
                 "end": end_token.to_string(),
             },
             "state": (
-                yield self._event_serializer.serialize_events(
+                await self._event_serializer.serialize_events(
                     room_state.values(), time_now
                 )
             ),
@@ -353,19 +346,18 @@ class InitialSyncHandler(BaseHandler):
             "receipts": [],
         }
 
-    @defer.inlineCallbacks
-    def _room_initial_sync_joined(
+    async def _room_initial_sync_joined(
         self, user_id, room_id, pagin_config, membership, is_peeking
     ):
-        current_state = yield self.state.get_current_state(room_id=room_id)
+        current_state = await self.state.get_current_state(room_id=room_id)
 
         # TODO: These concurrently
         time_now = self.clock.time_msec()
-        state = yield self._event_serializer.serialize_events(
+        state = await self._event_serializer.serialize_events(
             current_state.values(), time_now
         )
 
-        now_token = yield self.hs.get_event_sources().get_current_token()
+        now_token = await self.hs.get_event_sources().get_current_token()
 
         limit = pagin_config.limit if pagin_config else None
         if limit is None:
@@ -380,28 +372,26 @@ class InitialSyncHandler(BaseHandler):
 
         presence_handler = self.hs.get_presence_handler()
 
-        @defer.inlineCallbacks
-        def get_presence():
+        async def get_presence():
             # If presence is disabled, return an empty list
             if not self.hs.config.use_presence:
                 return []
 
-            states = yield presence_handler.get_states(
+            states = await presence_handler.get_states(
                 [m.user_id for m in room_members], as_event=True
             )
 
             return states
 
-        @defer.inlineCallbacks
-        def get_receipts():
-            receipts = yield self.store.get_linearized_receipts_for_room(
+        async def get_receipts():
+            receipts = await self.store.get_linearized_receipts_for_room(
                 room_id, to_key=now_token.receipt_key
             )
             if not receipts:
                 receipts = []
             return receipts
 
-        presence, receipts, (messages, token) = yield make_deferred_yieldable(
+        presence, receipts, (messages, token) = await make_deferred_yieldable(
             defer.gatherResults(
                 [
                     run_in_background(get_presence),
@@ -417,7 +407,7 @@ class InitialSyncHandler(BaseHandler):
             ).addErrback(unwrapFirstError)
         )
 
-        messages = yield filter_events_for_client(
+        messages = await filter_events_for_client(
             self.storage, user_id, messages, is_peeking=is_peeking
         )
 
@@ -430,7 +420,7 @@ class InitialSyncHandler(BaseHandler):
             "room_id": room_id,
             "messages": {
                 "chunk": (
-                    yield self._event_serializer.serialize_events(messages, time_now)
+                    await self._event_serializer.serialize_events(messages, time_now)
                 ),
                 "start": start_token.to_string(),
                 "end": end_token.to_string(),
@@ -444,18 +434,17 @@ class InitialSyncHandler(BaseHandler):
 
         return ret
 
-    @defer.inlineCallbacks
-    def _check_in_room_or_world_readable(self, room_id, user_id):
+    async def _check_in_room_or_world_readable(self, room_id, user_id):
         try:
             # check_user_was_in_room will return the most recent membership
             # event for the user if:
             #  * The user is a non-guest user, and was ever in the room
             #  * The user is a guest user, and has joined the room
             # else it will throw.
-            member_event = yield self.auth.check_user_was_in_room(room_id, user_id)
+            member_event = await self.auth.check_user_was_in_room(room_id, user_id)
             return member_event.membership, member_event.event_id
         except AuthError:
-            visibility = yield self.state_handler.get_current_state(
+            visibility = await self.state_handler.get_current_state(
                 room_id, EventTypes.RoomHistoryVisibility, ""
             )
             if (
diff --git a/synapse/handlers/message.py b/synapse/handlers/message.py
index 54fa216d83..4ad752205f 100644
--- a/synapse/handlers/message.py
+++ b/synapse/handlers/message.py
@@ -46,6 +46,7 @@ from synapse.events.validator import EventValidator
 from synapse.logging.context import run_in_background
 from synapse.metrics.background_process_metrics import run_as_background_process
 from synapse.replication.http.send_event import ReplicationSendEventRestServlet
+from synapse.storage.data_stores.main.events_worker import EventRedactBehaviour
 from synapse.storage.state import StateFilter
 from synapse.types import RoomAlias, UserID, create_requester
 from synapse.util.async_helpers import Linearizer
@@ -514,7 +515,7 @@ class EventCreationHandler(object):
             # federation as well as those created locally. As of room v3, aliases events
             # can be created by users that are not in the room, therefore we have to
             # tolerate them in event_auth.check().
-            prev_state_ids = yield context.get_prev_state_ids(self.store)
+            prev_state_ids = yield context.get_prev_state_ids()
             prev_event_id = prev_state_ids.get((EventTypes.Member, event.sender))
             prev_event = (
                 yield self.store.get_event(prev_event_id, allow_none=True)
@@ -664,7 +665,7 @@ class EventCreationHandler(object):
         If so, returns the version of the event in context.
         Otherwise, returns None.
         """
-        prev_state_ids = yield context.get_prev_state_ids(self.store)
+        prev_state_ids = yield context.get_prev_state_ids()
         prev_event_id = prev_state_ids.get((event.type, event.state_key))
         if not prev_event_id:
             return
@@ -875,7 +876,7 @@ class EventCreationHandler(object):
             if event.type == EventTypes.Redaction:
                 original_event = yield self.store.get_event(
                     event.redacts,
-                    check_redacted=False,
+                    redact_behaviour=EventRedactBehaviour.AS_IS,
                     get_prev_content=False,
                     allow_rejected=False,
                     allow_none=True,
@@ -913,7 +914,7 @@ class EventCreationHandler(object):
                 def is_inviter_member_event(e):
                     return e.type == EventTypes.Member and e.sender == event.sender
 
-                current_state_ids = yield context.get_current_state_ids(self.store)
+                current_state_ids = yield context.get_current_state_ids()
 
                 state_to_include_ids = [
                     e_id
@@ -952,7 +953,7 @@ class EventCreationHandler(object):
         if event.type == EventTypes.Redaction:
             original_event = yield self.store.get_event(
                 event.redacts,
-                check_redacted=False,
+                redact_behaviour=EventRedactBehaviour.AS_IS,
                 get_prev_content=False,
                 allow_rejected=False,
                 allow_none=True,
@@ -966,7 +967,7 @@ class EventCreationHandler(object):
                 if original_event.room_id != event.room_id:
                     raise SynapseError(400, "Cannot redact event from a different room")
 
-            prev_state_ids = yield context.get_prev_state_ids(self.store)
+            prev_state_ids = yield context.get_prev_state_ids()
             auth_events_ids = yield self.auth.compute_auth_events(
                 event, prev_state_ids, for_verification=True
             )
@@ -988,7 +989,7 @@ class EventCreationHandler(object):
                 event.internal_metadata.recheck_redaction = False
 
         if event.type == EventTypes.Create:
-            prev_state_ids = yield context.get_prev_state_ids(self.store)
+            prev_state_ids = yield context.get_prev_state_ids()
             if prev_state_ids:
                 raise AuthError(403, "Changing the room create event is forbidden")
 
diff --git a/synapse/handlers/pagination.py b/synapse/handlers/pagination.py
index 8514ddc600..00a6afc963 100644
--- a/synapse/handlers/pagination.py
+++ b/synapse/handlers/pagination.py
@@ -280,8 +280,7 @@ class PaginationHandler(object):
 
             await self.storage.purge_events.purge_room(room_id)
 
-    @defer.inlineCallbacks
-    def get_messages(
+    async def get_messages(
         self,
         requester,
         room_id=None,
@@ -307,7 +306,7 @@ class PaginationHandler(object):
             room_token = pagin_config.from_token.room_key
         else:
             pagin_config.from_token = (
-                yield self.hs.get_event_sources().get_current_token_for_pagination()
+                await self.hs.get_event_sources().get_current_token_for_pagination()
             )
             room_token = pagin_config.from_token.room_key
 
@@ -319,11 +318,11 @@ class PaginationHandler(object):
 
         source_config = pagin_config.get_source_config("room")
 
-        with (yield self.pagination_lock.read(room_id)):
+        with (await self.pagination_lock.read(room_id)):
             (
                 membership,
                 member_event_id,
-            ) = yield self.auth.check_in_room_or_world_readable(room_id, user_id)
+            ) = await self.auth.check_in_room_or_world_readable(room_id, user_id)
 
             if source_config.direction == "b":
                 # if we're going backwards, we might need to backfill. This
@@ -331,7 +330,7 @@ class PaginationHandler(object):
                 if room_token.topological:
                     max_topo = room_token.topological
                 else:
-                    max_topo = yield self.store.get_max_topological_token(
+                    max_topo = await self.store.get_max_topological_token(
                         room_id, room_token.stream
                     )
 
@@ -339,18 +338,18 @@ class PaginationHandler(object):
                     # If they have left the room then clamp the token to be before
                     # they left the room, to save the effort of loading from the
                     # database.
-                    leave_token = yield self.store.get_topological_token_for_event(
+                    leave_token = await self.store.get_topological_token_for_event(
                         member_event_id
                     )
                     leave_token = RoomStreamToken.parse(leave_token)
                     if leave_token.topological < max_topo:
                         source_config.from_key = str(leave_token)
 
-                yield self.hs.get_handlers().federation_handler.maybe_backfill(
+                await self.hs.get_handlers().federation_handler.maybe_backfill(
                     room_id, max_topo
                 )
 
-            events, next_key = yield self.store.paginate_room_events(
+            events, next_key = await self.store.paginate_room_events(
                 room_id=room_id,
                 from_key=source_config.from_key,
                 to_key=source_config.to_key,
@@ -365,7 +364,7 @@ class PaginationHandler(object):
             if event_filter:
                 events = event_filter.filter(events)
 
-            events = yield filter_events_for_client(
+            events = await filter_events_for_client(
                 self.storage, user_id, events, is_peeking=(member_event_id is None)
             )
 
@@ -385,19 +384,19 @@ class PaginationHandler(object):
                 (EventTypes.Member, event.sender) for event in events
             )
 
-            state_ids = yield self.state_store.get_state_ids_for_event(
+            state_ids = await self.state_store.get_state_ids_for_event(
                 events[0].event_id, state_filter=state_filter
             )
 
             if state_ids:
-                state = yield self.store.get_events(list(state_ids.values()))
+                state = await self.store.get_events(list(state_ids.values()))
                 state = state.values()
 
         time_now = self.clock.time_msec()
 
         chunk = {
             "chunk": (
-                yield self._event_serializer.serialize_events(
+                await self._event_serializer.serialize_events(
                     events, time_now, as_client_event=as_client_event
                 )
             ),
@@ -406,7 +405,7 @@ class PaginationHandler(object):
         }
 
         if state:
-            chunk["state"] = yield self._event_serializer.serialize_events(
+            chunk["state"] = await self._event_serializer.serialize_events(
                 state, time_now, as_client_event=as_client_event
             )
 
diff --git a/synapse/handlers/presence.py b/synapse/handlers/presence.py
index eda15bc623..240c4add12 100644
--- a/synapse/handlers/presence.py
+++ b/synapse/handlers/presence.py
@@ -230,7 +230,7 @@ class PresenceHandler(object):
         is some spurious presence changes that will self-correct.
         """
         # If the DB pool has already terminated, don't try updating
-        if not self.hs.get_db_pool().running:
+        if not self.store.database.is_running():
             return
 
         logger.info(
diff --git a/synapse/handlers/profile.py b/synapse/handlers/profile.py
index 1e5a4613c9..f9579d69ee 100644
--- a/synapse/handlers/profile.py
+++ b/synapse/handlers/profile.py
@@ -295,12 +295,16 @@ class BaseProfileHandler(BaseHandler):
                 be found to be in any room the server is in, and therefore the query
                 is denied.
         """
+
         # Implementation of MSC1301: don't allow looking up profiles if the
         # requester isn't in the same room as the target. We expect requester to
         # be None when this function is called outside of a profile query, e.g.
         # when building a membership event. In this case, we must allow the
         # lookup.
-        if not self.hs.config.require_auth_for_profile_requests or not requester:
+        if (
+            not self.hs.config.limit_profile_requests_to_users_who_share_rooms
+            or not requester
+        ):
             return
 
         # Always allow the user to query their own profile.
diff --git a/synapse/handlers/room.py b/synapse/handlers/room.py
index 60b8bbc7a5..89c9118b26 100644
--- a/synapse/handlers/room.py
+++ b/synapse/handlers/room.py
@@ -184,7 +184,7 @@ class RoomCreationHandler(BaseHandler):
             requester, tombstone_event, tombstone_context
         )
 
-        old_room_state = yield tombstone_context.get_current_state_ids(self.store)
+        old_room_state = yield tombstone_context.get_current_state_ids()
 
         # update any aliases
         yield self._move_aliases_to_new_room(
@@ -1011,15 +1011,3 @@ class RoomEventSource(object):
 
     def get_current_key_for_room(self, room_id):
         return self.store.get_room_events_max_id(room_id)
-
-    @defer.inlineCallbacks
-    def get_pagination_rows(self, user, config, key):
-        events, next_key = yield self.store.paginate_room_events(
-            room_id=key,
-            from_key=config.from_key,
-            to_key=config.to_key,
-            direction=config.direction,
-            limit=config.limit,
-        )
-
-        return (events, next_key)
diff --git a/synapse/handlers/room_member.py b/synapse/handlers/room_member.py
index 7b7270fc61..44c5e3239c 100644
--- a/synapse/handlers/room_member.py
+++ b/synapse/handlers/room_member.py
@@ -193,7 +193,7 @@ class RoomMemberHandler(object):
             requester, event, context, extra_users=[target], ratelimit=ratelimit
         )
 
-        prev_state_ids = yield context.get_prev_state_ids(self.store)
+        prev_state_ids = yield context.get_prev_state_ids()
 
         prev_member_event_id = prev_state_ids.get((EventTypes.Member, user_id), None)
 
@@ -601,7 +601,7 @@ class RoomMemberHandler(object):
         if prev_event is not None:
             return
 
-        prev_state_ids = yield context.get_prev_state_ids(self.store)
+        prev_state_ids = yield context.get_prev_state_ids()
         if event.membership == Membership.JOIN:
             if requester.is_guest:
                 guest_can_join = yield self._can_guest_join(prev_state_ids)
diff --git a/synapse/handlers/saml_handler.py b/synapse/handlers/saml_handler.py
index cc9e6b9bd0..0082f85c26 100644
--- a/synapse/handlers/saml_handler.py
+++ b/synapse/handlers/saml_handler.py
@@ -13,20 +13,36 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 import logging
+import re
+from typing import Tuple
 
 import attr
 import saml2
+import saml2.response
 from saml2.client import Saml2Client
 
 from synapse.api.errors import SynapseError
+from synapse.config import ConfigError
 from synapse.http.servlet import parse_string
 from synapse.rest.client.v1.login import SSOAuthHandler
-from synapse.types import UserID, map_username_to_mxid_localpart
+from synapse.types import (
+    UserID,
+    map_username_to_mxid_localpart,
+    mxid_localpart_allowed_characters,
+)
 from synapse.util.async_helpers import Linearizer
 
 logger = logging.getLogger(__name__)
 
 
+@attr.s
+class Saml2SessionData:
+    """Data we track about SAML2 sessions"""
+
+    # time the session was created, in milliseconds
+    creation_time = attr.ib()
+
+
 class SamlHandler:
     def __init__(self, hs):
         self._saml_client = Saml2Client(hs.config.saml2_sp_config)
@@ -37,11 +53,14 @@ class SamlHandler:
         self._datastore = hs.get_datastore()
         self._hostname = hs.hostname
         self._saml2_session_lifetime = hs.config.saml2_session_lifetime
-        self._mxid_source_attribute = hs.config.saml2_mxid_source_attribute
         self._grandfathered_mxid_source_attribute = (
             hs.config.saml2_grandfathered_mxid_source_attribute
         )
-        self._mxid_mapper = hs.config.saml2_mxid_mapper
+
+        # plugin to do custom mapping from saml response to mxid
+        self._user_mapping_provider = hs.config.saml2_user_mapping_provider_class(
+            hs.config.saml2_user_mapping_provider_config
+        )
 
         # identifier for the external_ids table
         self._auth_provider_id = "saml"
@@ -118,22 +137,10 @@ class SamlHandler:
             remote_user_id = saml2_auth.ava["uid"][0]
         except KeyError:
             logger.warning("SAML2 response lacks a 'uid' attestation")
-            raise SynapseError(400, "uid not in SAML2 response")
-
-        try:
-            mxid_source = saml2_auth.ava[self._mxid_source_attribute][0]
-        except KeyError:
-            logger.warning(
-                "SAML2 response lacks a '%s' attestation", self._mxid_source_attribute
-            )
-            raise SynapseError(
-                400, "%s not in SAML2 response" % (self._mxid_source_attribute,)
-            )
+            raise SynapseError(400, "'uid' not in SAML2 response")
 
         self._outstanding_requests_dict.pop(saml2_auth.in_response_to, None)
 
-        displayName = saml2_auth.ava.get("displayName", [None])[0]
-
         with (await self._mapping_lock.queue(self._auth_provider_id)):
             # first of all, check if we already have a mapping for this user
             logger.info(
@@ -173,22 +180,46 @@ class SamlHandler:
                     )
                     return registered_user_id
 
-            # figure out a new mxid for this user
-            base_mxid_localpart = self._mxid_mapper(mxid_source)
+            # Map saml response to user attributes using the configured mapping provider
+            for i in range(1000):
+                attribute_dict = self._user_mapping_provider.saml_response_to_user_attributes(
+                    saml2_auth, i
+                )
+
+                logger.debug(
+                    "Retrieved SAML attributes from user mapping provider: %s "
+                    "(attempt %d)",
+                    attribute_dict,
+                    i,
+                )
+
+                localpart = attribute_dict.get("mxid_localpart")
+                if not localpart:
+                    logger.error(
+                        "SAML mapping provider plugin did not return a "
+                        "mxid_localpart object"
+                    )
+                    raise SynapseError(500, "Error parsing SAML2 response")
 
-            suffix = 0
-            while True:
-                localpart = base_mxid_localpart + (str(suffix) if suffix else "")
+                displayname = attribute_dict.get("displayname")
+
+                # Check if this mxid already exists
                 if not await self._datastore.get_users_by_id_case_insensitive(
                     UserID(localpart, self._hostname).to_string()
                 ):
+                    # This mxid is free
                     break
-                suffix += 1
-            logger.info("Allocating mxid for new user with localpart %s", localpart)
+            else:
+                # Unable to generate a username in 1000 iterations
+                # Break and return error to the user
+                raise SynapseError(
+                    500, "Unable to generate a Matrix ID from the SAML response"
+                )
 
             registered_user_id = await self._registration_handler.register_user(
-                localpart=localpart, default_display_name=displayName
+                localpart=localpart, default_display_name=displayname
             )
+
             await self._datastore.record_user_external_id(
                 self._auth_provider_id, remote_user_id, registered_user_id
             )
@@ -205,9 +236,120 @@ class SamlHandler:
             del self._outstanding_requests_dict[reqid]
 
 
+DOT_REPLACE_PATTERN = re.compile(
+    ("[^%s]" % (re.escape("".join(mxid_localpart_allowed_characters)),))
+)
+
+
+def dot_replace_for_mxid(username: str) -> str:
+    username = username.lower()
+    username = DOT_REPLACE_PATTERN.sub(".", username)
+
+    # regular mxids aren't allowed to start with an underscore either
+    username = re.sub("^_", "", username)
+    return username
+
+
+MXID_MAPPER_MAP = {
+    "hexencode": map_username_to_mxid_localpart,
+    "dotreplace": dot_replace_for_mxid,
+}
+
+
 @attr.s
-class Saml2SessionData:
-    """Data we track about SAML2 sessions"""
+class SamlConfig(object):
+    mxid_source_attribute = attr.ib()
+    mxid_mapper = attr.ib()
 
-    # time the session was created, in milliseconds
-    creation_time = attr.ib()
+
+class DefaultSamlMappingProvider(object):
+    __version__ = "0.0.1"
+
+    def __init__(self, parsed_config: SamlConfig):
+        """The default SAML user mapping provider
+
+        Args:
+            parsed_config: Module configuration
+        """
+        self._mxid_source_attribute = parsed_config.mxid_source_attribute
+        self._mxid_mapper = parsed_config.mxid_mapper
+
+    def saml_response_to_user_attributes(
+        self, saml_response: saml2.response.AuthnResponse, failures: int = 0,
+    ) -> dict:
+        """Maps some text from a SAML response to attributes of a new user
+
+        Args:
+            saml_response: A SAML auth response object
+
+            failures: How many times a call to this function with this
+                saml_response has resulted in a failure
+
+        Returns:
+            dict: A dict containing new user attributes. Possible keys:
+                * mxid_localpart (str): Required. The localpart of the user's mxid
+                * displayname (str): The displayname of the user
+        """
+        try:
+            mxid_source = saml_response.ava[self._mxid_source_attribute][0]
+        except KeyError:
+            logger.warning(
+                "SAML2 response lacks a '%s' attestation", self._mxid_source_attribute,
+            )
+            raise SynapseError(
+                400, "%s not in SAML2 response" % (self._mxid_source_attribute,)
+            )
+
+        # Use the configured mapper for this mxid_source
+        base_mxid_localpart = self._mxid_mapper(mxid_source)
+
+        # Append suffix integer if last call to this function failed to produce
+        # a usable mxid
+        localpart = base_mxid_localpart + (str(failures) if failures else "")
+
+        # Retrieve the display name from the saml response
+        # If displayname is None, the mxid_localpart will be used instead
+        displayname = saml_response.ava.get("displayName", [None])[0]
+
+        return {
+            "mxid_localpart": localpart,
+            "displayname": displayname,
+        }
+
+    @staticmethod
+    def parse_config(config: dict) -> SamlConfig:
+        """Parse the dict provided by the homeserver's config
+        Args:
+            config: A dictionary containing configuration options for this provider
+        Returns:
+            SamlConfig: A custom config object for this module
+        """
+        # Parse config options and use defaults where necessary
+        mxid_source_attribute = config.get("mxid_source_attribute", "uid")
+        mapping_type = config.get("mxid_mapping", "hexencode")
+
+        # Retrieve the associating mapping function
+        try:
+            mxid_mapper = MXID_MAPPER_MAP[mapping_type]
+        except KeyError:
+            raise ConfigError(
+                "saml2_config.user_mapping_provider.config: '%s' is not a valid "
+                "mxid_mapping value" % (mapping_type,)
+            )
+
+        return SamlConfig(mxid_source_attribute, mxid_mapper)
+
+    @staticmethod
+    def get_saml_attributes(config: SamlConfig) -> Tuple[set, set]:
+        """Returns the required attributes of a SAML
+
+        Args:
+            config: A SamlConfig object containing configuration params for this provider
+
+        Returns:
+            tuple[set,set]: The first set equates to the saml auth response
+                attributes that are required for the module to function, whereas the
+                second set consists of those attributes which can be used if
+                available, but are not necessary
+        """
+        return {"uid", config.mxid_source_attribute}, {"displayName"}
diff --git a/synapse/handlers/search.py b/synapse/handlers/search.py
index 56ed262a1f..ef750d1497 100644
--- a/synapse/handlers/search.py
+++ b/synapse/handlers/search.py
@@ -21,7 +21,7 @@ from unpaddedbase64 import decode_base64, encode_base64
 from twisted.internet import defer
 
 from synapse.api.constants import EventTypes, Membership
-from synapse.api.errors import SynapseError
+from synapse.api.errors import NotFoundError, SynapseError
 from synapse.api.filtering import Filter
 from synapse.storage.state import StateFilter
 from synapse.visibility import filter_events_for_client
@@ -37,6 +37,7 @@ class SearchHandler(BaseHandler):
         self._event_serializer = hs.get_event_client_serializer()
         self.storage = hs.get_storage()
         self.state_store = self.storage.state
+        self.auth = hs.get_auth()
 
     @defer.inlineCallbacks
     def get_old_rooms_from_upgraded_room(self, room_id):
@@ -53,23 +54,38 @@ class SearchHandler(BaseHandler):
             room_id (str): id of the room to search through.
 
         Returns:
-            Deferred[iterable[unicode]]: predecessor room ids
+            Deferred[iterable[str]]: predecessor room ids
         """
 
         historical_room_ids = []
 
-        while True:
-            predecessor = yield self.store.get_room_predecessor(room_id)
+        # The initial room must have been known for us to get this far
+        predecessor = yield self.store.get_room_predecessor(room_id)
 
-            # If no predecessor, assume we've hit a dead end
+        while True:
             if not predecessor:
+                # We have reached the end of the chain of predecessors
+                break
+
+            if not isinstance(predecessor.get("room_id"), str):
+                # This predecessor object is malformed. Exit here
+                break
+
+            predecessor_room_id = predecessor["room_id"]
+
+            # Don't add it to the list until we have checked that we are in the room
+            try:
+                next_predecessor_room = yield self.store.get_room_predecessor(
+                    predecessor_room_id
+                )
+            except NotFoundError:
+                # The predecessor is not a known room, so we are done here
                 break
 
-            # Add predecessor's room ID
-            historical_room_ids.append(predecessor["room_id"])
+            historical_room_ids.append(predecessor_room_id)
 
-            # Scan through the old room for further predecessors
-            room_id = predecessor["room_id"]
+            # And repeat
+            predecessor = next_predecessor_room
 
         return historical_room_ids
 
diff --git a/synapse/handlers/typing.py b/synapse/handlers/typing.py
index 6f78454322..b635c339ed 100644
--- a/synapse/handlers/typing.py
+++ b/synapse/handlers/typing.py
@@ -317,6 +317,3 @@ class TypingNotificationEventSource(object):
 
     def get_current_key(self):
         return self.get_typing_handler()._latest_room_serial
-
-    def get_pagination_rows(self, user, pagination_config, key):
-        return [], pagination_config.from_key
diff --git a/synapse/logging/_terse_json.py b/synapse/logging/_terse_json.py
index 03934956f4..c0b9384189 100644
--- a/synapse/logging/_terse_json.py
+++ b/synapse/logging/_terse_json.py
@@ -171,7 +171,7 @@ class LogProducer(object):
 
     def stopProducing(self):
         self._paused = True
-        self._buffer = None
+        self._buffer = deque()
 
     def resumeProducing(self):
         self._paused = False
diff --git a/synapse/logging/context.py b/synapse/logging/context.py
index 2c1fb9ddac..33b322209d 100644
--- a/synapse/logging/context.py
+++ b/synapse/logging/context.py
@@ -23,6 +23,7 @@ them.
 See doc/log_contexts.rst for details on how this works.
 """
 
+import inspect
 import logging
 import threading
 import types
@@ -404,6 +405,9 @@ class LoggingContext(object):
         """
         current = get_thread_resource_usage()
 
+        # Indicate to mypy that we know that self.usage_start is None.
+        assert self.usage_start is not None
+
         utime_delta = current.ru_utime - self.usage_start.ru_utime
         stime_delta = current.ru_stime - self.usage_start.ru_stime
 
@@ -612,7 +616,8 @@ def run_in_background(f, *args, **kwargs):
 
 
 def make_deferred_yieldable(deferred):
-    """Given a deferred, make it follow the Synapse logcontext rules:
+    """Given a deferred (or coroutine), make it follow the Synapse logcontext
+    rules:
 
     If the deferred has completed (or is not actually a Deferred), essentially
     does nothing (just returns another completed deferred with the
@@ -624,6 +629,13 @@ def make_deferred_yieldable(deferred):
 
     (This is more-or-less the opposite operation to run_in_background.)
     """
+    if inspect.isawaitable(deferred):
+        # If we're given a coroutine we convert it to a deferred so that we
+        # run it and find out if it immediately finishes, it it does then we
+        # don't need to fiddle with log contexts at all and can return
+        # immediately.
+        deferred = defer.ensureDeferred(deferred)
+
     if not isinstance(deferred, defer.Deferred):
         return deferred
 
diff --git a/synapse/push/bulk_push_rule_evaluator.py b/synapse/push/bulk_push_rule_evaluator.py
index 7881780760..7d9f5a38d9 100644
--- a/synapse/push/bulk_push_rule_evaluator.py
+++ b/synapse/push/bulk_push_rule_evaluator.py
@@ -116,7 +116,7 @@ class BulkPushRuleEvaluator(object):
 
     @defer.inlineCallbacks
     def _get_power_levels_and_sender_level(self, event, context):
-        prev_state_ids = yield context.get_prev_state_ids(self.store)
+        prev_state_ids = yield context.get_prev_state_ids()
         pl_event_id = prev_state_ids.get(POWER_KEY)
         if pl_event_id:
             # fastpath: if there's a power level event, that's all we need, and
@@ -304,7 +304,7 @@ class RulesForRoom(object):
 
                 push_rules_delta_state_cache_metric.inc_hits()
             else:
-                current_state_ids = yield context.get_current_state_ids(self.store)
+                current_state_ids = yield context.get_current_state_ids()
                 push_rules_delta_state_cache_metric.inc_misses()
 
             push_rules_state_size_counter.inc(len(current_state_ids))
diff --git a/synapse/push/pusher.py b/synapse/push/pusher.py
index f277aeb131..8ad0bf5936 100644
--- a/synapse/push/pusher.py
+++ b/synapse/push/pusher.py
@@ -80,9 +80,11 @@ class PusherFactory(object):
         return EmailPusher(self.hs, pusherdict, mailer)
 
     def _app_name_from_pusherdict(self, pusherdict):
-        if "data" in pusherdict and "brand" in pusherdict["data"]:
-            app_name = pusherdict["data"]["brand"]
-        else:
-            app_name = self.config.email_app_name
+        data = pusherdict["data"]
 
-        return app_name
+        if isinstance(data, dict):
+            brand = data.get("brand")
+            if isinstance(brand, str):
+                return brand
+
+        return self.config.email_app_name
diff --git a/synapse/push/pusherpool.py b/synapse/push/pusherpool.py
index 0f6992202d..b9dca5bc63 100644
--- a/synapse/push/pusherpool.py
+++ b/synapse/push/pusherpool.py
@@ -232,7 +232,6 @@ class PusherPool:
             Deferred
         """
         pushers = yield self.store.get_all_pushers()
-        logger.info("Starting %d pushers", len(pushers))
 
         # Stagger starting up the pushers so we don't completely drown the
         # process on start up.
@@ -245,7 +244,7 @@ class PusherPool:
         """Start the given pusher
 
         Args:
-            pusherdict (dict):
+            pusherdict (dict): dict with the values pulled from the db table
 
         Returns:
             Deferred[EmailPusher|HttpPusher]
@@ -254,7 +253,8 @@ class PusherPool:
             p = self.pusher_factory.create_pusher(pusherdict)
         except PusherConfigException as e:
             logger.warning(
-                "Pusher incorrectly configured user=%s, appid=%s, pushkey=%s: %s",
+                "Pusher incorrectly configured id=%i, user=%s, appid=%s, pushkey=%s: %s",
+                pusherdict["id"],
                 pusherdict.get("user_name"),
                 pusherdict.get("app_id"),
                 pusherdict.get("pushkey"),
@@ -262,7 +262,9 @@ class PusherPool:
             )
             return
         except Exception:
-            logger.exception("Couldn't start a pusher: caught Exception")
+            logger.exception(
+                "Couldn't start pusher id %i: caught Exception", pusherdict["id"],
+            )
             return
 
         if not p:
diff --git a/synapse/replication/http/federation.py b/synapse/replication/http/federation.py
index 9af4e7e173..49a3251372 100644
--- a/synapse/replication/http/federation.py
+++ b/synapse/replication/http/federation.py
@@ -51,6 +51,7 @@ class ReplicationFederationSendEventsRestServlet(ReplicationEndpoint):
         super(ReplicationFederationSendEventsRestServlet, self).__init__(hs)
 
         self.store = hs.get_datastore()
+        self.storage = hs.get_storage()
         self.clock = hs.get_clock()
         self.federation_handler = hs.get_handlers().federation_handler
 
@@ -100,7 +101,9 @@ class ReplicationFederationSendEventsRestServlet(ReplicationEndpoint):
                 EventType = event_type_from_format_version(format_ver)
                 event = EventType(event_dict, internal_metadata, rejected_reason)
 
-                context = EventContext.deserialize(self.store, event_payload["context"])
+                context = EventContext.deserialize(
+                    self.storage, event_payload["context"]
+                )
 
                 event_and_contexts.append((event, context))
 
diff --git a/synapse/replication/http/send_event.py b/synapse/replication/http/send_event.py
index 9bafd60b14..84b92f16ad 100644
--- a/synapse/replication/http/send_event.py
+++ b/synapse/replication/http/send_event.py
@@ -54,6 +54,7 @@ class ReplicationSendEventRestServlet(ReplicationEndpoint):
 
         self.event_creation_handler = hs.get_event_creation_handler()
         self.store = hs.get_datastore()
+        self.storage = hs.get_storage()
         self.clock = hs.get_clock()
 
     @staticmethod
@@ -100,7 +101,7 @@ class ReplicationSendEventRestServlet(ReplicationEndpoint):
             event = EventType(event_dict, internal_metadata, rejected_reason)
 
             requester = Requester.deserialize(self.store, content["requester"])
-            context = EventContext.deserialize(self.store, content["context"])
+            context = EventContext.deserialize(self.storage, content["context"])
 
             ratelimit = content["ratelimit"]
             extra_users = [UserID.from_string(u) for u in content["extra_users"]]
diff --git a/synapse/rest/client/v1/pusher.py b/synapse/rest/client/v1/pusher.py
index 0791866f55..6f6b7aed6e 100644
--- a/synapse/rest/client/v1/pusher.py
+++ b/synapse/rest/client/v1/pusher.py
@@ -28,6 +28,17 @@ from synapse.rest.client.v2_alpha._base import client_patterns
 
 logger = logging.getLogger(__name__)
 
+ALLOWED_KEYS = {
+    "app_display_name",
+    "app_id",
+    "data",
+    "device_display_name",
+    "kind",
+    "lang",
+    "profile_tag",
+    "pushkey",
+}
+
 
 class PushersRestServlet(RestServlet):
     PATTERNS = client_patterns("/pushers$", v1=True)
@@ -43,23 +54,11 @@ class PushersRestServlet(RestServlet):
 
         pushers = await self.hs.get_datastore().get_pushers_by_user_id(user.to_string())
 
-        allowed_keys = [
-            "app_display_name",
-            "app_id",
-            "data",
-            "device_display_name",
-            "kind",
-            "lang",
-            "profile_tag",
-            "pushkey",
-        ]
-
-        for p in pushers:
-            for k, v in list(p.items()):
-                if k not in allowed_keys:
-                    del p[k]
-
-        return 200, {"pushers": pushers}
+        filtered_pushers = list(
+            {k: v for k, v in p.items() if k in ALLOWED_KEYS} for p in pushers
+        )
+
+        return 200, {"pushers": filtered_pushers}
 
     def on_OPTIONS(self, _):
         return 200, {}
diff --git a/synapse/server.py b/synapse/server.py
index 2db3dab221..7926867b77 100644
--- a/synapse/server.py
+++ b/synapse/server.py
@@ -25,7 +25,6 @@ import abc
 import logging
 import os
 
-from twisted.enterprise import adbapi
 from twisted.mail.smtp import sendmail
 from twisted.web.client import BrowserLikePolicyForHTTPS
 
@@ -34,6 +33,7 @@ from synapse.api.filtering import Filtering
 from synapse.api.ratelimiting import Ratelimiter
 from synapse.appservice.api import ApplicationServiceApi
 from synapse.appservice.scheduler import ApplicationServiceScheduler
+from synapse.config.homeserver import HomeServerConfig
 from synapse.crypto import context_factory
 from synapse.crypto.keyring import Keyring
 from synapse.events.builder import EventBuilderFactory
@@ -132,7 +132,6 @@ class HomeServer(object):
 
     DEPENDENCIES = [
         "http_client",
-        "db_pool",
         "federation_client",
         "federation_server",
         "handlers",
@@ -209,16 +208,18 @@ class HomeServer(object):
     # instantiated during setup() for future return by get_datastore()
     DATASTORE_CLASS = abc.abstractproperty()
 
-    def __init__(self, hostname, reactor=None, **kwargs):
+    def __init__(self, hostname: str, config: HomeServerConfig, reactor=None, **kwargs):
         """
         Args:
             hostname : The hostname for the server.
+            config: The full config for the homeserver.
         """
         if not reactor:
             from twisted.internet import reactor
 
         self._reactor = reactor
         self.hostname = hostname
+        self.config = config
         self._building = {}
         self._listening_services = []
         self.start_time = None
@@ -237,10 +238,8 @@ class HomeServer(object):
 
     def setup(self):
         logger.info("Setting up.")
-        with self.get_db_conn() as conn:
-            self.datastores = DataStores(self.DATASTORE_CLASS, conn, self)
-            conn.commit()
         self.start_time = int(self.get_clock().time())
+        self.datastores = DataStores(self.DATASTORE_CLASS, self)
         logger.info("Finished setting up.")
 
     def setup_master(self):
@@ -274,6 +273,9 @@ class HomeServer(object):
     def get_datastore(self):
         return self.datastores.main
 
+    def get_datastores(self):
+        return self.datastores
+
     def get_config(self):
         return self.config
 
@@ -423,31 +425,6 @@ class HomeServer(object):
         )
         return MatrixFederationHttpClient(self, tls_client_options_factory)
 
-    def build_db_pool(self):
-        name = self.db_config["name"]
-
-        return adbapi.ConnectionPool(
-            name, cp_reactor=self.get_reactor(), **self.db_config.get("args", {})
-        )
-
-    def get_db_conn(self, run_new_connection=True):
-        """Makes a new connection to the database, skipping the db pool
-
-        Returns:
-            Connection: a connection object implementing the PEP-249 spec
-        """
-        # Any param beginning with cp_ is a parameter for adbapi, and should
-        # not be passed to the database engine.
-        db_params = {
-            k: v
-            for k, v in self.db_config.get("args", {}).items()
-            if not k.startswith("cp_")
-        }
-        db_conn = self.database_engine.module.connect(**db_params)
-        if run_new_connection:
-            self.database_engine.on_new_connection(db_conn)
-        return db_conn
-
     def build_media_repository_resource(self):
         # build the media repo resource. This indirects through the HomeServer
         # to ensure that we only have a single instance of
diff --git a/synapse/state/__init__.py b/synapse/state/__init__.py
index 0e75e94c6f..5accc071ab 100644
--- a/synapse/state/__init__.py
+++ b/synapse/state/__init__.py
@@ -32,6 +32,7 @@ from synapse.events import EventBase
 from synapse.events.snapshot import EventContext
 from synapse.logging.utils import log_function
 from synapse.state import v1, v2
+from synapse.storage.data_stores.main.events_worker import EventRedactBehaviour
 from synapse.util.async_helpers import Linearizer
 from synapse.util.caches import get_cache_factor_for
 from synapse.util.caches.expiringcache import ExpiringCache
@@ -655,7 +656,7 @@ class StateResolutionStore(object):
 
         return self.store.get_events(
             event_ids,
-            check_redacted=False,
+            redact_behaviour=EventRedactBehaviour.AS_IS,
             get_prev_content=False,
             allow_rejected=allow_rejected,
         )
diff --git a/synapse/storage/_base.py b/synapse/storage/_base.py
index b7637b5dc0..88546ad614 100644
--- a/synapse/storage/_base.py
+++ b/synapse/storage/_base.py
@@ -40,7 +40,7 @@ class SQLBaseStore(object):
     def __init__(self, database: Database, db_conn, hs):
         self.hs = hs
         self._clock = hs.get_clock()
-        self.database_engine = hs.database_engine
+        self.database_engine = database.engine
         self.db = database
         self.rand = random.SystemRandom()
 
diff --git a/synapse/storage/data_stores/__init__.py b/synapse/storage/data_stores/__init__.py
index cafedd5c0d..d20df5f076 100644
--- a/synapse/storage/data_stores/__init__.py
+++ b/synapse/storage/data_stores/__init__.py
@@ -13,24 +13,55 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
-from synapse.storage.database import Database
+import logging
+
+from synapse.storage.data_stores.state import StateGroupDataStore
+from synapse.storage.database import Database, make_conn
+from synapse.storage.engines import create_engine
 from synapse.storage.prepare_database import prepare_database
 
+logger = logging.getLogger(__name__)
+
 
 class DataStores(object):
     """The various data stores.
 
     These are low level interfaces to physical databases.
+
+    Attributes:
+        main (DataStore)
     """
 
-    def __init__(self, main_store_class, db_conn, hs):
+    def __init__(self, main_store_class, hs):
         # Note we pass in the main store class here as workers use a different main
         # store.
-        database = Database(hs)
 
-        # Check that db is correctly configured.
-        database.engine.check_database(db_conn.cursor())
+        self.databases = []
+
+        for database_config in hs.config.database.databases:
+            db_name = database_config.name
+            engine = create_engine(database_config.config)
+
+            with make_conn(database_config, engine) as db_conn:
+                logger.info("Preparing database %r...", db_name)
+
+                engine.check_database(db_conn.cursor())
+                prepare_database(
+                    db_conn, engine, hs.config, data_stores=database_config.data_stores,
+                )
+
+                database = Database(hs, database_config, engine)
+
+                if "main" in database_config.data_stores:
+                    logger.info("Starting 'main' data store")
+                    self.main = main_store_class(database, db_conn, hs)
+
+                if "state" in database_config.data_stores:
+                    logger.info("Starting 'state' data store")
+                    self.state = StateGroupDataStore(database, db_conn, hs)
+
+                db_conn.commit()
 
-        prepare_database(db_conn, database.engine, config=hs.config)
+                self.databases.append(database)
 
-        self.main = main_store_class(database, db_conn, hs)
+                logger.info("Database %r prepared", db_name)
diff --git a/synapse/storage/data_stores/main/client_ips.py b/synapse/storage/data_stores/main/client_ips.py
index 320c5b0f07..13f4c9c72e 100644
--- a/synapse/storage/data_stores/main/client_ips.py
+++ b/synapse/storage/data_stores/main/client_ips.py
@@ -412,7 +412,7 @@ class ClientIpStore(ClientIpBackgroundUpdateStore):
     def _update_client_ips_batch(self):
 
         # If the DB pool has already terminated, don't try updating
-        if not self.hs.get_db_pool().running:
+        if not self.db.is_running():
             return
 
         to_update = self._batch_row_update
@@ -451,16 +451,18 @@ class ClientIpStore(ClientIpBackgroundUpdateStore):
                 # Technically an access token might not be associated with
                 # a device so we need to check.
                 if device_id:
-                    self.db.simple_upsert_txn(
+                    # this is always an update rather than an upsert: the row should
+                    # already exist, and if it doesn't, that may be because it has been
+                    # deleted, and we don't want to re-create it.
+                    self.db.simple_update_txn(
                         txn,
                         table="devices",
                         keyvalues={"user_id": user_id, "device_id": device_id},
-                        values={
+                        updatevalues={
                             "user_agent": user_agent,
                             "last_seen": last_seen,
                             "ip": ip,
                         },
-                        lock=False,
                     )
             except Exception as e:
                 # Failed to upsert, log and continue
diff --git a/synapse/storage/data_stores/main/deviceinbox.py b/synapse/storage/data_stores/main/deviceinbox.py
index 85cfa16850..0613b49f4a 100644
--- a/synapse/storage/data_stores/main/deviceinbox.py
+++ b/synapse/storage/data_stores/main/deviceinbox.py
@@ -358,21 +358,8 @@ class DeviceInboxStore(DeviceInboxWorkerStore, DeviceInboxBackgroundUpdateStore)
     def _add_messages_to_local_device_inbox_txn(
         self, txn, stream_id, messages_by_user_then_device
     ):
-        # Compatible method of performing an upsert
-        sql = "SELECT stream_id FROM device_max_stream_id"
-
-        txn.execute(sql)
-        rows = txn.fetchone()
-        if rows:
-            db_stream_id = rows[0]
-            if db_stream_id < stream_id:
-                # Insert the new stream_id
-                sql = "UPDATE device_max_stream_id SET stream_id = ?"
-        else:
-            # No rows, perform an insert
-            sql = "INSERT INTO device_max_stream_id (stream_id) VALUES (?)"
-
-        txn.execute(sql, (stream_id,))
+        sql = "UPDATE device_max_stream_id" " SET stream_id = ?" " WHERE stream_id < ?"
+        txn.execute(sql, (stream_id, stream_id))
 
         local_by_user_then_device = {}
         for user_id, messages_by_device in messages_by_user_then_device.items():
diff --git a/synapse/storage/data_stores/main/end_to_end_keys.py b/synapse/storage/data_stores/main/end_to_end_keys.py
index 38cd0ca9b8..e551606f9d 100644
--- a/synapse/storage/data_stores/main/end_to_end_keys.py
+++ b/synapse/storage/data_stores/main/end_to_end_keys.py
@@ -14,15 +14,18 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
+from typing import Dict, List
+
 from six import iteritems
 
 from canonicaljson import encode_canonical_json, json
 
+from twisted.enterprise.adbapi import Connection
 from twisted.internet import defer
 
 from synapse.logging.opentracing import log_kv, set_tag, trace
 from synapse.storage._base import SQLBaseStore, db_to_json
-from synapse.util.caches.descriptors import cached
+from synapse.util.caches.descriptors import cached, cachedList
 
 
 class EndToEndKeyWorkerStore(SQLBaseStore):
@@ -271,7 +274,7 @@ class EndToEndKeyWorkerStore(SQLBaseStore):
         Args:
             txn (twisted.enterprise.adbapi.Connection): db connection
             user_id (str): the user whose key is being requested
-            key_type (str): the type of key that is being set: either 'master'
+            key_type (str): the type of key that is being requested: either 'master'
                 for a master key, 'self_signing' for a self-signing key, or
                 'user_signing' for a user-signing key
             from_user_id (str): if specified, signatures made by this user on
@@ -316,8 +319,10 @@ class EndToEndKeyWorkerStore(SQLBaseStore):
         """Returns a user's cross-signing key.
 
         Args:
-            user_id (str): the user whose self-signing key is being requested
-            key_type (str): the type of cross-signing key to get
+            user_id (str): the user whose key is being requested
+            key_type (str): the type of key that is being requested: either 'master'
+                for a master key, 'self_signing' for a self-signing key, or
+                'user_signing' for a user-signing key
             from_user_id (str): if specified, signatures made by this user on
                 the self-signing key will be included in the result
 
@@ -332,6 +337,206 @@ class EndToEndKeyWorkerStore(SQLBaseStore):
             from_user_id,
         )
 
+    @cached(num_args=1)
+    def _get_bare_e2e_cross_signing_keys(self, user_id):
+        """Dummy function.  Only used to make a cache for
+        _get_bare_e2e_cross_signing_keys_bulk.
+        """
+        raise NotImplementedError()
+
+    @cachedList(
+        cached_method_name="_get_bare_e2e_cross_signing_keys",
+        list_name="user_ids",
+        num_args=1,
+    )
+    def _get_bare_e2e_cross_signing_keys_bulk(
+        self, user_ids: List[str]
+    ) -> Dict[str, Dict[str, dict]]:
+        """Returns the cross-signing keys for a set of users.  The output of this
+        function should be passed to _get_e2e_cross_signing_signatures_txn if
+        the signatures for the calling user need to be fetched.
+
+        Args:
+            user_ids (list[str]): the users whose keys are being requested
+
+        Returns:
+            dict[str, dict[str, dict]]: mapping from user ID to key type to key
+                data.  If a user's cross-signing keys were not found, either
+                their user ID will not be in the dict, or their user ID will map
+                to None.
+
+        """
+        return self.db.runInteraction(
+            "get_bare_e2e_cross_signing_keys_bulk",
+            self._get_bare_e2e_cross_signing_keys_bulk_txn,
+            user_ids,
+        )
+
+    def _get_bare_e2e_cross_signing_keys_bulk_txn(
+        self, txn: Connection, user_ids: List[str],
+    ) -> Dict[str, Dict[str, dict]]:
+        """Returns the cross-signing keys for a set of users.  The output of this
+        function should be passed to _get_e2e_cross_signing_signatures_txn if
+        the signatures for the calling user need to be fetched.
+
+        Args:
+            txn (twisted.enterprise.adbapi.Connection): db connection
+            user_ids (list[str]): the users whose keys are being requested
+
+        Returns:
+            dict[str, dict[str, dict]]: mapping from user ID to key type to key
+                data.  If a user's cross-signing keys were not found, their user
+                ID will not be in the dict.
+
+        """
+        result = {}
+
+        batch_size = 100
+        chunks = [
+            user_ids[i : i + batch_size] for i in range(0, len(user_ids), batch_size)
+        ]
+        for user_chunk in chunks:
+            sql = """
+                SELECT k.user_id, k.keytype, k.keydata, k.stream_id
+                  FROM e2e_cross_signing_keys k
+                  INNER JOIN (SELECT user_id, keytype, MAX(stream_id) AS stream_id
+                                FROM e2e_cross_signing_keys
+                               GROUP BY user_id, keytype) s
+                 USING (user_id, stream_id, keytype)
+                 WHERE k.user_id IN (%s)
+            """ % (
+                ",".join("?" for u in user_chunk),
+            )
+            query_params = []
+            query_params.extend(user_chunk)
+
+            txn.execute(sql, query_params)
+            rows = self.db.cursor_to_dict(txn)
+
+            for row in rows:
+                user_id = row["user_id"]
+                key_type = row["keytype"]
+                key = json.loads(row["keydata"])
+                user_info = result.setdefault(user_id, {})
+                user_info[key_type] = key
+
+        return result
+
+    def _get_e2e_cross_signing_signatures_txn(
+        self, txn: Connection, keys: Dict[str, Dict[str, dict]], from_user_id: str,
+    ) -> Dict[str, Dict[str, dict]]:
+        """Returns the cross-signing signatures made by a user on a set of keys.
+
+        Args:
+            txn (twisted.enterprise.adbapi.Connection): db connection
+            keys (dict[str, dict[str, dict]]): a map of user ID to key type to
+                key data.  This dict will be modified to add signatures.
+            from_user_id (str): fetch the signatures made by this user
+
+        Returns:
+            dict[str, dict[str, dict]]: mapping from user ID to key type to key
+                data.  The return value will be the same as the keys argument,
+                with the modifications included.
+        """
+
+        # find out what cross-signing keys (a.k.a. devices) we need to get
+        # signatures for.  This is a map of (user_id, device_id) to key type
+        # (device_id is the key's public part).
+        devices = {}
+
+        for user_id, user_info in keys.items():
+            if user_info is None:
+                continue
+            for key_type, key in user_info.items():
+                device_id = None
+                for k in key["keys"].values():
+                    device_id = k
+                devices[(user_id, device_id)] = key_type
+
+        device_list = list(devices)
+
+        # split into batches
+        batch_size = 100
+        chunks = [
+            device_list[i : i + batch_size]
+            for i in range(0, len(device_list), batch_size)
+        ]
+        for user_chunk in chunks:
+            sql = """
+                SELECT target_user_id, target_device_id, key_id, signature
+                  FROM e2e_cross_signing_signatures
+                 WHERE user_id = ?
+                   AND (%s)
+            """ % (
+                " OR ".join(
+                    "(target_user_id = ? AND target_device_id = ?)" for d in devices
+                )
+            )
+            query_params = [from_user_id]
+            for item in devices:
+                # item is a (user_id, device_id) tuple
+                query_params.extend(item)
+
+            txn.execute(sql, query_params)
+            rows = self.db.cursor_to_dict(txn)
+
+            # and add the signatures to the appropriate keys
+            for row in rows:
+                key_id = row["key_id"]
+                target_user_id = row["target_user_id"]
+                target_device_id = row["target_device_id"]
+                key_type = devices[(target_user_id, target_device_id)]
+                # We need to copy everything, because the result may have come
+                # from the cache.  dict.copy only does a shallow copy, so we
+                # need to recursively copy the dicts that will be modified.
+                user_info = keys[target_user_id] = keys[target_user_id].copy()
+                target_user_key = user_info[key_type] = user_info[key_type].copy()
+                if "signatures" in target_user_key:
+                    signatures = target_user_key["signatures"] = target_user_key[
+                        "signatures"
+                    ].copy()
+                    if from_user_id in signatures:
+                        user_sigs = signatures[from_user_id] = signatures[from_user_id]
+                        user_sigs[key_id] = row["signature"]
+                    else:
+                        signatures[from_user_id] = {key_id: row["signature"]}
+                else:
+                    target_user_key["signatures"] = {
+                        from_user_id: {key_id: row["signature"]}
+                    }
+
+        return keys
+
+    @defer.inlineCallbacks
+    def get_e2e_cross_signing_keys_bulk(
+        self, user_ids: List[str], from_user_id: str = None
+    ) -> defer.Deferred:
+        """Returns the cross-signing keys for a set of users.
+
+        Args:
+            user_ids (list[str]): the users whose keys are being requested
+            from_user_id (str): if specified, signatures made by this user on
+                the self-signing keys will be included in the result
+
+        Returns:
+            Deferred[dict[str, dict[str, dict]]]: map of user ID to key type to
+                key data.  If a user's cross-signing keys were not found, either
+                their user ID will not be in the dict, or their user ID will map
+                to None.
+        """
+
+        result = yield self._get_bare_e2e_cross_signing_keys_bulk(user_ids)
+
+        if from_user_id:
+            result = yield self.db.runInteraction(
+                "get_e2e_cross_signing_signatures",
+                self._get_e2e_cross_signing_signatures_txn,
+                result,
+                from_user_id,
+            )
+
+        return result
+
     def get_all_user_signature_changes_for_remotes(self, from_key, to_key):
         """Return a list of changes from the user signature stream to notify remotes.
         Note that the user signature stream represents when a user signs their
@@ -520,6 +725,10 @@ class EndToEndKeyStore(EndToEndKeyWorkerStore, SQLBaseStore):
                 },
             )
 
+        self._invalidate_cache_and_stream(
+            txn, self._get_bare_e2e_cross_signing_keys, (user_id,)
+        )
+
     def set_e2e_cross_signing_key(self, user_id, key_type, key):
         """Set a user's cross-signing key.
 
diff --git a/synapse/storage/data_stores/main/events.py b/synapse/storage/data_stores/main/events.py
index 998bba1aad..58f35d7f56 100644
--- a/synapse/storage/data_stores/main/events.py
+++ b/synapse/storage/data_stores/main/events.py
@@ -1757,163 +1757,6 @@ class EventsStore(
 
         return state_groups
 
-    def purge_unreferenced_state_groups(
-        self, room_id: str, state_groups_to_delete
-    ) -> defer.Deferred:
-        """Deletes no longer referenced state groups and de-deltas any state
-        groups that reference them.
-
-        Args:
-            room_id: The room the state groups belong to (must all be in the
-                same room).
-            state_groups_to_delete (Collection[int]): Set of all state groups
-                to delete.
-        """
-
-        return self.db.runInteraction(
-            "purge_unreferenced_state_groups",
-            self._purge_unreferenced_state_groups,
-            room_id,
-            state_groups_to_delete,
-        )
-
-    def _purge_unreferenced_state_groups(self, txn, room_id, state_groups_to_delete):
-        logger.info(
-            "[purge] found %i state groups to delete", len(state_groups_to_delete)
-        )
-
-        rows = self.db.simple_select_many_txn(
-            txn,
-            table="state_group_edges",
-            column="prev_state_group",
-            iterable=state_groups_to_delete,
-            keyvalues={},
-            retcols=("state_group",),
-        )
-
-        remaining_state_groups = set(
-            row["state_group"]
-            for row in rows
-            if row["state_group"] not in state_groups_to_delete
-        )
-
-        logger.info(
-            "[purge] de-delta-ing %i remaining state groups",
-            len(remaining_state_groups),
-        )
-
-        # Now we turn the state groups that reference to-be-deleted state
-        # groups to non delta versions.
-        for sg in remaining_state_groups:
-            logger.info("[purge] de-delta-ing remaining state group %s", sg)
-            curr_state = self._get_state_groups_from_groups_txn(txn, [sg])
-            curr_state = curr_state[sg]
-
-            self.db.simple_delete_txn(
-                txn, table="state_groups_state", keyvalues={"state_group": sg}
-            )
-
-            self.db.simple_delete_txn(
-                txn, table="state_group_edges", keyvalues={"state_group": sg}
-            )
-
-            self.db.simple_insert_many_txn(
-                txn,
-                table="state_groups_state",
-                values=[
-                    {
-                        "state_group": sg,
-                        "room_id": room_id,
-                        "type": key[0],
-                        "state_key": key[1],
-                        "event_id": state_id,
-                    }
-                    for key, state_id in iteritems(curr_state)
-                ],
-            )
-
-        logger.info("[purge] removing redundant state groups")
-        txn.executemany(
-            "DELETE FROM state_groups_state WHERE state_group = ?",
-            ((sg,) for sg in state_groups_to_delete),
-        )
-        txn.executemany(
-            "DELETE FROM state_groups WHERE id = ?",
-            ((sg,) for sg in state_groups_to_delete),
-        )
-
-    @defer.inlineCallbacks
-    def get_previous_state_groups(self, state_groups):
-        """Fetch the previous groups of the given state groups.
-
-        Args:
-            state_groups (Iterable[int])
-
-        Returns:
-            Deferred[dict[int, int]]: mapping from state group to previous
-            state group.
-        """
-
-        rows = yield self.db.simple_select_many_batch(
-            table="state_group_edges",
-            column="prev_state_group",
-            iterable=state_groups,
-            keyvalues={},
-            retcols=("prev_state_group", "state_group"),
-            desc="get_previous_state_groups",
-        )
-
-        return {row["state_group"]: row["prev_state_group"] for row in rows}
-
-    def purge_room_state(self, room_id, state_groups_to_delete):
-        """Deletes all record of a room from state tables
-
-        Args:
-            room_id (str):
-            state_groups_to_delete (list[int]): State groups to delete
-        """
-
-        return self.db.runInteraction(
-            "purge_room_state",
-            self._purge_room_state_txn,
-            room_id,
-            state_groups_to_delete,
-        )
-
-    def _purge_room_state_txn(self, txn, room_id, state_groups_to_delete):
-        # first we have to delete the state groups states
-        logger.info("[purge] removing %s from state_groups_state", room_id)
-
-        self.db.simple_delete_many_txn(
-            txn,
-            table="state_groups_state",
-            column="state_group",
-            iterable=state_groups_to_delete,
-            keyvalues={},
-        )
-
-        # ... and the state group edges
-        logger.info("[purge] removing %s from state_group_edges", room_id)
-
-        self.db.simple_delete_many_txn(
-            txn,
-            table="state_group_edges",
-            column="state_group",
-            iterable=state_groups_to_delete,
-            keyvalues={},
-        )
-
-        # ... and the state groups
-        logger.info("[purge] removing %s from state_groups", room_id)
-
-        self.db.simple_delete_many_txn(
-            txn,
-            table="state_groups",
-            column="id",
-            iterable=state_groups_to_delete,
-            keyvalues={},
-        )
-
     async def is_event_after(self, event_id1, event_id2):
         """Returns True if event_id1 is after event_id2 in the stream
         """
diff --git a/synapse/storage/data_stores/main/events_worker.py b/synapse/storage/data_stores/main/events_worker.py
index 9ee117ce0f..2c9142814c 100644
--- a/synapse/storage/data_stores/main/events_worker.py
+++ b/synapse/storage/data_stores/main/events_worker.py
@@ -19,8 +19,10 @@ import itertools
 import logging
 import threading
 from collections import namedtuple
+from typing import List, Optional
 
 from canonicaljson import json
+from constantly import NamedConstant, Names
 
 from twisted.internet import defer
 
@@ -55,6 +57,16 @@ EVENT_QUEUE_TIMEOUT_S = 0.1  # Timeout when waiting for requests for events
 _EventCacheEntry = namedtuple("_EventCacheEntry", ("event", "redacted_event"))
 
 
+class EventRedactBehaviour(Names):
+    """
+    What to do when retrieving a redacted event from the database.
+    """
+
+    AS_IS = NamedConstant()
+    REDACT = NamedConstant()
+    BLOCK = NamedConstant()
+
+
 class EventsWorkerStore(SQLBaseStore):
     def __init__(self, database: Database, db_conn, hs):
         super(EventsWorkerStore, self).__init__(database, db_conn, hs)
@@ -125,25 +137,27 @@ class EventsWorkerStore(SQLBaseStore):
     @defer.inlineCallbacks
     def get_event(
         self,
-        event_id,
-        check_redacted=True,
-        get_prev_content=False,
-        allow_rejected=False,
-        allow_none=False,
-        check_room_id=None,
+        event_id: List[str],
+        redact_behaviour: EventRedactBehaviour = EventRedactBehaviour.REDACT,
+        get_prev_content: bool = False,
+        allow_rejected: bool = False,
+        allow_none: bool = False,
+        check_room_id: Optional[str] = None,
     ):
         """Get an event from the database by event_id.
 
         Args:
-            event_id (str): The event_id of the event to fetch
-            check_redacted (bool): If True, check if event has been redacted
-                and redact it.
-            get_prev_content (bool): If True and event is a state event,
+            event_id: The event_id of the event to fetch
+            redact_behaviour: Determine what to do with a redacted event. Possible values:
+                * AS_IS - Return the full event body with no redacted content
+                * REDACT - Return the event but with a redacted body
+                * DISALLOW - Do not return redacted events
+            get_prev_content: If True and event is a state event,
                 include the previous states content in the unsigned field.
-            allow_rejected (bool): If True return rejected events.
-            allow_none (bool): If True, return None if no event found, if
+            allow_rejected: If True return rejected events.
+            allow_none: If True, return None if no event found, if
                 False throw a NotFoundError
-            check_room_id (str|None): if not None, check the room of the found event.
+            check_room_id: if not None, check the room of the found event.
                 If there is a mismatch, behave as per allow_none.
 
         Returns:
@@ -154,7 +168,7 @@ class EventsWorkerStore(SQLBaseStore):
 
         events = yield self.get_events_as_list(
             [event_id],
-            check_redacted=check_redacted,
+            redact_behaviour=redact_behaviour,
             get_prev_content=get_prev_content,
             allow_rejected=allow_rejected,
         )
@@ -173,27 +187,30 @@ class EventsWorkerStore(SQLBaseStore):
     @defer.inlineCallbacks
     def get_events(
         self,
-        event_ids,
-        check_redacted=True,
-        get_prev_content=False,
-        allow_rejected=False,
+        event_ids: List[str],
+        redact_behaviour: EventRedactBehaviour = EventRedactBehaviour.REDACT,
+        get_prev_content: bool = False,
+        allow_rejected: bool = False,
     ):
         """Get events from the database
 
         Args:
-            event_ids (list): The event_ids of the events to fetch
-            check_redacted (bool): If True, check if event has been redacted
-                and redact it.
-            get_prev_content (bool): If True and event is a state event,
+            event_ids: The event_ids of the events to fetch
+            redact_behaviour: Determine what to do with a redacted event. Possible
+                values:
+                * AS_IS - Return the full event body with no redacted content
+                * REDACT - Return the event but with a redacted body
+                * DISALLOW - Do not return redacted events
+            get_prev_content: If True and event is a state event,
                 include the previous states content in the unsigned field.
-            allow_rejected (bool): If True return rejected events.
+            allow_rejected: If True return rejected events.
 
         Returns:
             Deferred : Dict from event_id to event.
         """
         events = yield self.get_events_as_list(
             event_ids,
-            check_redacted=check_redacted,
+            redact_behaviour=redact_behaviour,
             get_prev_content=get_prev_content,
             allow_rejected=allow_rejected,
         )
@@ -203,21 +220,23 @@ class EventsWorkerStore(SQLBaseStore):
     @defer.inlineCallbacks
     def get_events_as_list(
         self,
-        event_ids,
-        check_redacted=True,
-        get_prev_content=False,
-        allow_rejected=False,
+        event_ids: List[str],
+        redact_behaviour: EventRedactBehaviour = EventRedactBehaviour.REDACT,
+        get_prev_content: bool = False,
+        allow_rejected: bool = False,
     ):
         """Get events from the database and return in a list in the same order
         as given by `event_ids` arg.
 
         Args:
-            event_ids (list): The event_ids of the events to fetch
-            check_redacted (bool): If True, check if event has been redacted
-                and redact it.
-            get_prev_content (bool): If True and event is a state event,
+            event_ids: The event_ids of the events to fetch
+            redact_behaviour: Determine what to do with a redacted event. Possible values:
+                * AS_IS - Return the full event body with no redacted content
+                * REDACT - Return the event but with a redacted body
+                * DISALLOW - Do not return redacted events
+            get_prev_content: If True and event is a state event,
                 include the previous states content in the unsigned field.
-            allow_rejected (bool): If True return rejected events.
+            allow_rejected: If True, return rejected events.
 
         Returns:
             Deferred[list[EventBase]]: List of events fetched from the database. The
@@ -319,10 +338,14 @@ class EventsWorkerStore(SQLBaseStore):
                     # Update the cache to save doing the checks again.
                     entry.event.internal_metadata.recheck_redaction = False
 
-            if check_redacted and entry.redacted_event:
-                event = entry.redacted_event
-            else:
-                event = entry.event
+            event = entry.event
+
+            if entry.redacted_event:
+                if redact_behaviour == EventRedactBehaviour.BLOCK:
+                    # Skip this event
+                    continue
+                elif redact_behaviour == EventRedactBehaviour.REDACT:
+                    event = entry.redacted_event
 
             events.append(event)
 
diff --git a/synapse/storage/data_stores/main/push_rule.py b/synapse/storage/data_stores/main/push_rule.py
index 5ba13aa973..e2673ae073 100644
--- a/synapse/storage/data_stores/main/push_rule.py
+++ b/synapse/storage/data_stores/main/push_rule.py
@@ -244,7 +244,7 @@ class PushRulesWorkerStore(
             # To do this we set the state_group to a new object as object() != object()
             state_group = object()
 
-        current_state_ids = yield context.get_current_state_ids(self)
+        current_state_ids = yield context.get_current_state_ids()
         result = yield self._bulk_get_push_rules_for_room(
             event.room_id, state_group, current_state_ids, event=event
         )
diff --git a/synapse/storage/data_stores/main/pusher.py b/synapse/storage/data_stores/main/pusher.py
index f07309ef09..6b03233262 100644
--- a/synapse/storage/data_stores/main/pusher.py
+++ b/synapse/storage/data_stores/main/pusher.py
@@ -15,8 +15,7 @@
 # limitations under the License.
 
 import logging
-
-import six
+from typing import Iterable, Iterator
 
 from canonicaljson import encode_canonical_json, json
 
@@ -27,21 +26,16 @@ from synapse.util.caches.descriptors import cachedInlineCallbacks, cachedList
 
 logger = logging.getLogger(__name__)
 
-if six.PY2:
-    db_binary_type = six.moves.builtins.buffer
-else:
-    db_binary_type = memoryview
-
 
 class PusherWorkerStore(SQLBaseStore):
-    def _decode_pushers_rows(self, rows):
+    def _decode_pushers_rows(self, rows: Iterable[dict]) -> Iterator[dict]:
+        """JSON-decode the data in the rows returned from the `pushers` table
+
+        Drops any rows whose data cannot be decoded
+        """
         for r in rows:
             dataJson = r["data"]
-            r["data"] = None
             try:
-                if isinstance(dataJson, db_binary_type):
-                    dataJson = str(dataJson).decode("UTF8")
-
                 r["data"] = json.loads(dataJson)
             except Exception as e:
                 logger.warning(
@@ -50,12 +44,9 @@ class PusherWorkerStore(SQLBaseStore):
                     dataJson,
                     e.args[0],
                 )
-                pass
-
-            if isinstance(r["pushkey"], db_binary_type):
-                r["pushkey"] = str(r["pushkey"]).decode("UTF8")
+                continue
 
-        return rows
+            yield r
 
     @defer.inlineCallbacks
     def user_has_pusher(self, user_id):
diff --git a/synapse/storage/data_stores/main/roommember.py b/synapse/storage/data_stores/main/roommember.py
index 92e3b9c512..70ff5751b6 100644
--- a/synapse/storage/data_stores/main/roommember.py
+++ b/synapse/storage/data_stores/main/roommember.py
@@ -477,7 +477,7 @@ class RoomMemberWorkerStore(EventsWorkerStore):
             # To do this we set the state_group to a new object as object() != object()
             state_group = object()
 
-        current_state_ids = yield context.get_current_state_ids(self)
+        current_state_ids = yield context.get_current_state_ids()
         result = yield self._get_joined_users_from_context(
             event.room_id, state_group, current_state_ids, event=event, context=context
         )
diff --git a/synapse/storage/data_stores/main/schema/delta/32/remove_indices.sql b/synapse/storage/data_stores/main/schema/delta/32/remove_indices.sql
index 4219cdd06a..2de50d408c 100644
--- a/synapse/storage/data_stores/main/schema/delta/32/remove_indices.sql
+++ b/synapse/storage/data_stores/main/schema/delta/32/remove_indices.sql
@@ -20,7 +20,6 @@ DROP INDEX IF EXISTS events_room_id; -- Prefix of events_room_stream
 DROP INDEX IF EXISTS events_order; -- Prefix of events_order_topo_stream_room
 DROP INDEX IF EXISTS events_topological_ordering; -- Prefix of events_order_topo_stream_room
 DROP INDEX IF EXISTS events_stream_ordering; -- Duplicate of PRIMARY KEY
-DROP INDEX IF EXISTS state_groups_id; -- Duplicate of PRIMARY KEY
 DROP INDEX IF EXISTS event_to_state_groups_id; -- Duplicate of PRIMARY KEY
 DROP INDEX IF EXISTS event_push_actions_room_id_event_id_user_id_profile_tag; -- Duplicate of UNIQUE CONSTRAINT
 
diff --git a/synapse/storage/data_stores/main/schema/delta/56/device_stream_id_insert.sql b/synapse/storage/data_stores/main/schema/delta/56/device_stream_id_insert.sql
new file mode 100644
index 0000000000..c2f557fde9
--- /dev/null
+++ b/synapse/storage/data_stores/main/schema/delta/56/device_stream_id_insert.sql
@@ -0,0 +1,20 @@
+/* Copyright 2019 The Matrix.org Foundation C.I.C.
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+-- This line already existed in deltas/35/device_stream_id but was not included in the
+-- 54 full schema SQL. Add some SQL here to insert the missing row if it does not exist
+INSERT INTO device_max_stream_id (stream_id) SELECT 0 WHERE NOT EXISTS (
+    SELECT * from device_max_stream_id
+);
\ No newline at end of file
diff --git a/synapse/storage/data_stores/main/schema/delta/56/nuke_empty_communities_from_db.sql b/synapse/storage/data_stores/main/schema/delta/56/nuke_empty_communities_from_db.sql
new file mode 100644
index 0000000000..4f24c1405d
--- /dev/null
+++ b/synapse/storage/data_stores/main/schema/delta/56/nuke_empty_communities_from_db.sql
@@ -0,0 +1,29 @@
+/* Copyright 2019 Werner Sembach
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+-- Groups/communities now get deleted when the last member leaves. This is a one time cleanup to remove old groups/communities that were already empty before that change was made.
+DELETE FROM group_attestations_remote WHERE group_id IN (SELECT group_id FROM groups WHERE NOT EXISTS (SELECT group_id FROM group_users WHERE group_id = groups.group_id));
+DELETE FROM group_attestations_renewals WHERE group_id IN (SELECT group_id FROM groups WHERE NOT EXISTS (SELECT group_id FROM group_users WHERE group_id = groups.group_id));
+DELETE FROM group_invites WHERE group_id IN (SELECT group_id FROM groups WHERE NOT EXISTS (SELECT group_id FROM group_users WHERE group_id = groups.group_id));
+DELETE FROM group_roles WHERE group_id IN (SELECT group_id FROM groups WHERE NOT EXISTS (SELECT group_id FROM group_users WHERE group_id = groups.group_id));
+DELETE FROM group_room_categories WHERE group_id IN (SELECT group_id FROM groups WHERE NOT EXISTS (SELECT group_id FROM group_users WHERE group_id = groups.group_id));
+DELETE FROM group_rooms WHERE group_id IN (SELECT group_id FROM groups WHERE NOT EXISTS (SELECT group_id FROM group_users WHERE group_id = groups.group_id));
+DELETE FROM group_summary_roles WHERE group_id IN (SELECT group_id FROM groups WHERE NOT EXISTS (SELECT group_id FROM group_users WHERE group_id = groups.group_id));
+DELETE FROM group_summary_room_categories WHERE group_id IN (SELECT group_id FROM groups WHERE NOT EXISTS (SELECT group_id FROM group_users WHERE group_id = groups.group_id));
+DELETE FROM group_summary_rooms WHERE group_id IN (SELECT group_id FROM groups WHERE NOT EXISTS (SELECT group_id FROM group_users WHERE group_id = groups.group_id));
+DELETE FROM group_summary_users WHERE group_id IN (SELECT group_id FROM groups WHERE NOT EXISTS (SELECT group_id FROM group_users WHERE group_id = groups.group_id));
+DELETE FROM local_group_membership WHERE group_id IN (SELECT group_id FROM groups WHERE NOT EXISTS (SELECT group_id FROM group_users WHERE group_id = groups.group_id));
+DELETE FROM local_group_updates WHERE group_id IN (SELECT group_id FROM groups WHERE NOT EXISTS (SELECT group_id FROM group_users WHERE group_id = groups.group_id));
+DELETE FROM groups WHERE group_id IN (SELECT group_id FROM groups WHERE NOT EXISTS (SELECT group_id FROM group_users WHERE group_id = groups.group_id));
diff --git a/synapse/storage/data_stores/main/schema/full_schemas/54/full.sql.postgres b/synapse/storage/data_stores/main/schema/full_schemas/54/full.sql.postgres
index 4ad2929f32..889a9a0ce4 100644
--- a/synapse/storage/data_stores/main/schema/full_schemas/54/full.sql.postgres
+++ b/synapse/storage/data_stores/main/schema/full_schemas/54/full.sql.postgres
@@ -975,40 +975,6 @@ CREATE TABLE state_events (
 
 
 
-CREATE TABLE state_group_edges (
-    state_group bigint NOT NULL,
-    prev_state_group bigint NOT NULL
-);
-
-
-
-CREATE SEQUENCE state_group_id_seq
-    START WITH 1
-    INCREMENT BY 1
-    NO MINVALUE
-    NO MAXVALUE
-    CACHE 1;
-
-
-
-CREATE TABLE state_groups (
-    id bigint NOT NULL,
-    room_id text NOT NULL,
-    event_id text NOT NULL
-);
-
-
-
-CREATE TABLE state_groups_state (
-    state_group bigint NOT NULL,
-    room_id text NOT NULL,
-    type text NOT NULL,
-    state_key text NOT NULL,
-    event_id text NOT NULL
-);
-
-
-
 CREATE TABLE stats_stream_pos (
     lock character(1) DEFAULT 'X'::bpchar NOT NULL,
     stream_id bigint,
@@ -1482,12 +1448,6 @@ ALTER TABLE ONLY state_events
     ADD CONSTRAINT state_events_event_id_key UNIQUE (event_id);
 
 
-
-ALTER TABLE ONLY state_groups
-    ADD CONSTRAINT state_groups_pkey PRIMARY KEY (id);
-
-
-
 ALTER TABLE ONLY stats_stream_pos
     ADD CONSTRAINT stats_stream_pos_lock_key UNIQUE (lock);
 
@@ -1928,18 +1888,6 @@ CREATE UNIQUE INDEX room_stats_room_ts ON room_stats USING btree (room_id, ts);
 
 
 
-CREATE INDEX state_group_edges_idx ON state_group_edges USING btree (state_group);
-
-
-
-CREATE INDEX state_group_edges_prev_idx ON state_group_edges USING btree (prev_state_group);
-
-
-
-CREATE INDEX state_groups_state_type_idx ON state_groups_state USING btree (state_group, type, state_key);
-
-
-
 CREATE INDEX stream_ordering_to_exterm_idx ON stream_ordering_to_exterm USING btree (stream_ordering);
 
 
diff --git a/synapse/storage/data_stores/main/schema/full_schemas/54/full.sql.sqlite b/synapse/storage/data_stores/main/schema/full_schemas/54/full.sql.sqlite
index bad33291e7..a0411ede7e 100644
--- a/synapse/storage/data_stores/main/schema/full_schemas/54/full.sql.sqlite
+++ b/synapse/storage/data_stores/main/schema/full_schemas/54/full.sql.sqlite
@@ -42,8 +42,6 @@ CREATE INDEX ev_edges_id ON event_edges(event_id);
 CREATE INDEX ev_edges_prev_id ON event_edges(prev_event_id);
 CREATE TABLE room_depth( room_id TEXT NOT NULL, min_depth INTEGER NOT NULL, UNIQUE (room_id) );
 CREATE INDEX room_depth_room ON room_depth(room_id);
-CREATE TABLE state_groups( id BIGINT PRIMARY KEY, room_id TEXT NOT NULL, event_id TEXT NOT NULL );
-CREATE TABLE state_groups_state( state_group BIGINT NOT NULL, room_id TEXT NOT NULL, type TEXT NOT NULL, state_key TEXT NOT NULL, event_id TEXT NOT NULL );
 CREATE TABLE event_to_state_groups( event_id TEXT NOT NULL, state_group BIGINT NOT NULL, UNIQUE (event_id) );
 CREATE TABLE local_media_repository ( media_id TEXT, media_type TEXT, media_length INTEGER, created_ts BIGINT, upload_name TEXT, user_id TEXT, quarantined_by TEXT, url_cache TEXT, last_access_ts BIGINT, UNIQUE (media_id) );
 CREATE TABLE local_media_repository_thumbnails ( media_id TEXT, thumbnail_width INTEGER, thumbnail_height INTEGER, thumbnail_type TEXT, thumbnail_method TEXT, thumbnail_length INTEGER, UNIQUE ( media_id, thumbnail_width, thumbnail_height, thumbnail_type ) );
@@ -120,9 +118,6 @@ CREATE TABLE device_max_stream_id ( stream_id BIGINT NOT NULL );
 CREATE TABLE public_room_list_stream ( stream_id BIGINT NOT NULL, room_id TEXT NOT NULL, visibility BOOLEAN NOT NULL , appservice_id TEXT, network_id TEXT);
 CREATE INDEX public_room_list_stream_idx on public_room_list_stream( stream_id );
 CREATE INDEX public_room_list_stream_rm_idx on public_room_list_stream( room_id, stream_id );
-CREATE TABLE state_group_edges( state_group BIGINT NOT NULL, prev_state_group BIGINT NOT NULL );
-CREATE INDEX state_group_edges_idx ON state_group_edges(state_group);
-CREATE INDEX state_group_edges_prev_idx ON state_group_edges(prev_state_group);
 CREATE TABLE stream_ordering_to_exterm ( stream_ordering BIGINT NOT NULL, room_id TEXT NOT NULL, event_id TEXT NOT NULL );
 CREATE INDEX stream_ordering_to_exterm_idx on stream_ordering_to_exterm( stream_ordering );
 CREATE INDEX stream_ordering_to_exterm_rm_idx on stream_ordering_to_exterm( room_id, stream_ordering );
@@ -254,6 +249,5 @@ CREATE INDEX user_ips_last_seen_only ON user_ips (last_seen);
 CREATE INDEX users_creation_ts ON users (creation_ts);
 CREATE INDEX event_to_state_groups_sg_index ON event_to_state_groups (state_group);
 CREATE UNIQUE INDEX device_lists_remote_cache_unique_id ON device_lists_remote_cache (user_id, device_id);
-CREATE INDEX state_groups_state_type_idx ON state_groups_state(state_group, type, state_key);
 CREATE UNIQUE INDEX device_lists_remote_extremeties_unique_idx ON device_lists_remote_extremeties (user_id);
 CREATE UNIQUE INDEX user_ips_user_token_ip_unique_index ON user_ips (user_id, access_token, ip);
diff --git a/synapse/storage/data_stores/main/schema/full_schemas/54/stream_positions.sql b/synapse/storage/data_stores/main/schema/full_schemas/54/stream_positions.sql
index c265fd20e2..91d21b2921 100644
--- a/synapse/storage/data_stores/main/schema/full_schemas/54/stream_positions.sql
+++ b/synapse/storage/data_stores/main/schema/full_schemas/54/stream_positions.sql
@@ -5,3 +5,4 @@ INSERT INTO federation_stream_position (type, stream_id) SELECT 'events', coales
 INSERT INTO user_directory_stream_pos (stream_id) VALUES (0);
 INSERT INTO stats_stream_pos (stream_id) VALUES (0);
 INSERT INTO event_push_summary_stream_ordering (stream_ordering) VALUES (0);
+-- device_max_stream_id is handled separately in 56/device_stream_id_insert.sql
\ No newline at end of file
diff --git a/synapse/storage/data_stores/main/schema/full_schemas/README.md b/synapse/storage/data_stores/main/schema/full_schemas/README.md
new file mode 100644
index 0000000000..bbd3f18604
--- /dev/null
+++ b/synapse/storage/data_stores/main/schema/full_schemas/README.md
@@ -0,0 +1,13 @@
+# Building full schema dumps
+
+These schemas need to be made from a database that has had all background updates run.
+
+To do so, use `scripts-dev/make_full_schema.sh`. This will produce
+`full.sql.postgres ` and `full.sql.sqlite` files.
+
+Ensure postgres is installed and your user has the ability to run bash commands
+such as `createdb`.
+
+```
+./scripts-dev/make_full_schema.sh -p postgres_username -o output_dir/
+```
diff --git a/synapse/storage/data_stores/main/schema/full_schemas/README.txt b/synapse/storage/data_stores/main/schema/full_schemas/README.txt
deleted file mode 100644
index d3f6401344..0000000000
--- a/synapse/storage/data_stores/main/schema/full_schemas/README.txt
+++ /dev/null
@@ -1,19 +0,0 @@
-Building full schema dumps
-==========================
-
-These schemas need to be made from a database that has had all background updates run.
-
-Postgres
---------
-
-$ pg_dump --format=plain --schema-only --no-tablespaces --no-acl --no-owner $DATABASE_NAME| sed -e '/^--/d' -e 's/public\.//g' -e '/^SET /d' -e '/^SELECT /d' > full.sql.postgres
-
-SQLite
-------
-
-$ sqlite3 $DATABASE_FILE ".schema" > full.sql.sqlite
-
-After
------
-
-Delete the CREATE statements for "sqlite_stat1", "schema_version", "applied_schema_deltas", and "applied_module_schemas".
\ No newline at end of file
diff --git a/synapse/storage/data_stores/main/search.py b/synapse/storage/data_stores/main/search.py
index 4eec2fae5e..47ebb8a214 100644
--- a/synapse/storage/data_stores/main/search.py
+++ b/synapse/storage/data_stores/main/search.py
@@ -25,6 +25,7 @@ from twisted.internet import defer
 
 from synapse.api.errors import SynapseError
 from synapse.storage._base import SQLBaseStore, make_in_list_sql_clause
+from synapse.storage.data_stores.main.events_worker import EventRedactBehaviour
 from synapse.storage.database import Database
 from synapse.storage.engines import PostgresEngine, Sqlite3Engine
 
@@ -384,7 +385,7 @@ class SearchStore(SearchBackgroundUpdateStore):
         """
         clauses = []
 
-        search_query = search_query = _parse_query(self.database_engine, search_term)
+        search_query = _parse_query(self.database_engine, search_term)
 
         args = []
 
@@ -453,7 +454,12 @@ class SearchStore(SearchBackgroundUpdateStore):
 
         results = list(filter(lambda row: row["room_id"] in room_ids, results))
 
-        events = yield self.get_events_as_list([r["event_id"] for r in results])
+        # We set redact_behaviour to BLOCK here to prevent redacted events being returned in
+        # search results (which is a data leak)
+        events = yield self.get_events_as_list(
+            [r["event_id"] for r in results],
+            redact_behaviour=EventRedactBehaviour.BLOCK,
+        )
 
         event_map = {ev.event_id: ev for ev in events}
 
@@ -495,7 +501,7 @@ class SearchStore(SearchBackgroundUpdateStore):
         """
         clauses = []
 
-        search_query = search_query = _parse_query(self.database_engine, search_term)
+        search_query = _parse_query(self.database_engine, search_term)
 
         args = []
 
@@ -600,7 +606,12 @@ class SearchStore(SearchBackgroundUpdateStore):
 
         results = list(filter(lambda row: row["room_id"] in room_ids, results))
 
-        events = yield self.get_events_as_list([r["event_id"] for r in results])
+        # We set redact_behaviour to BLOCK here to prevent redacted events being returned in
+        # search results (which is a data leak)
+        events = yield self.get_events_as_list(
+            [r["event_id"] for r in results],
+            redact_behaviour=EventRedactBehaviour.BLOCK,
+        )
 
         event_map = {ev.event_id: ev for ev in events}
 
diff --git a/synapse/storage/data_stores/main/state.py b/synapse/storage/data_stores/main/state.py
index 9ef7b48c74..0dc39f139c 100644
--- a/synapse/storage/data_stores/main/state.py
+++ b/synapse/storage/data_stores/main/state.py
@@ -17,8 +17,7 @@ import logging
 from collections import namedtuple
 from typing import Iterable, Tuple
 
-from six import iteritems, itervalues
-from six.moves import range
+from six import iteritems
 
 from twisted.internet import defer
 
@@ -29,11 +28,9 @@ from synapse.events.snapshot import EventContext
 from synapse.storage._base import SQLBaseStore
 from synapse.storage.data_stores.main.events_worker import EventsWorkerStore
 from synapse.storage.database import Database
-from synapse.storage.engines import PostgresEngine
 from synapse.storage.state import StateFilter
-from synapse.util.caches import get_cache_factor_for, intern_string
+from synapse.util.caches import intern_string
 from synapse.util.caches.descriptors import cached, cachedList
-from synapse.util.caches.dictionary_cache import DictionaryCache
 from synapse.util.stringutils import to_ascii
 
 logger = logging.getLogger(__name__)
@@ -55,207 +52,14 @@ class _GetStateGroupDelta(
         return len(self.delta_ids) if self.delta_ids else 0
 
 
-class StateGroupBackgroundUpdateStore(SQLBaseStore):
-    """Defines functions related to state groups needed to run the state backgroud
-    updates.
-    """
-
-    def _count_state_group_hops_txn(self, txn, state_group):
-        """Given a state group, count how many hops there are in the tree.
-
-        This is used to ensure the delta chains don't get too long.
-        """
-        if isinstance(self.database_engine, PostgresEngine):
-            sql = """
-                WITH RECURSIVE state(state_group) AS (
-                    VALUES(?::bigint)
-                    UNION ALL
-                    SELECT prev_state_group FROM state_group_edges e, state s
-                    WHERE s.state_group = e.state_group
-                )
-                SELECT count(*) FROM state;
-            """
-
-            txn.execute(sql, (state_group,))
-            row = txn.fetchone()
-            if row and row[0]:
-                return row[0]
-            else:
-                return 0
-        else:
-            # We don't use WITH RECURSIVE on sqlite3 as there are distributions
-            # that ship with an sqlite3 version that doesn't support it (e.g. wheezy)
-            next_group = state_group
-            count = 0
-
-            while next_group:
-                next_group = self.db.simple_select_one_onecol_txn(
-                    txn,
-                    table="state_group_edges",
-                    keyvalues={"state_group": next_group},
-                    retcol="prev_state_group",
-                    allow_none=True,
-                )
-                if next_group:
-                    count += 1
-
-            return count
-
-    def _get_state_groups_from_groups_txn(
-        self, txn, groups, state_filter=StateFilter.all()
-    ):
-        results = {group: {} for group in groups}
-
-        where_clause, where_args = state_filter.make_sql_filter_clause()
-
-        # Unless the filter clause is empty, we're going to append it after an
-        # existing where clause
-        if where_clause:
-            where_clause = " AND (%s)" % (where_clause,)
-
-        if isinstance(self.database_engine, PostgresEngine):
-            # Temporarily disable sequential scans in this transaction. This is
-            # a temporary hack until we can add the right indices in
-            txn.execute("SET LOCAL enable_seqscan=off")
-
-            # The below query walks the state_group tree so that the "state"
-            # table includes all state_groups in the tree. It then joins
-            # against `state_groups_state` to fetch the latest state.
-            # It assumes that previous state groups are always numerically
-            # lesser.
-            # The PARTITION is used to get the event_id in the greatest state
-            # group for the given type, state_key.
-            # This may return multiple rows per (type, state_key), but last_value
-            # should be the same.
-            sql = """
-                WITH RECURSIVE state(state_group) AS (
-                    VALUES(?::bigint)
-                    UNION ALL
-                    SELECT prev_state_group FROM state_group_edges e, state s
-                    WHERE s.state_group = e.state_group
-                )
-                SELECT DISTINCT type, state_key, last_value(event_id) OVER (
-                    PARTITION BY type, state_key ORDER BY state_group ASC
-                    ROWS BETWEEN UNBOUNDED PRECEDING AND UNBOUNDED FOLLOWING
-                ) AS event_id FROM state_groups_state
-                WHERE state_group IN (
-                    SELECT state_group FROM state
-                )
-            """
-
-            for group in groups:
-                args = [group]
-                args.extend(where_args)
-
-                txn.execute(sql + where_clause, args)
-                for row in txn:
-                    typ, state_key, event_id = row
-                    key = (typ, state_key)
-                    results[group][key] = event_id
-        else:
-            max_entries_returned = state_filter.max_entries_returned()
-
-            # We don't use WITH RECURSIVE on sqlite3 as there are distributions
-            # that ship with an sqlite3 version that doesn't support it (e.g. wheezy)
-            for group in groups:
-                next_group = group
-
-                while next_group:
-                    # We did this before by getting the list of group ids, and
-                    # then passing that list to sqlite to get latest event for
-                    # each (type, state_key). However, that was terribly slow
-                    # without the right indices (which we can't add until
-                    # after we finish deduping state, which requires this func)
-                    args = [next_group]
-                    args.extend(where_args)
-
-                    txn.execute(
-                        "SELECT type, state_key, event_id FROM state_groups_state"
-                        " WHERE state_group = ? " + where_clause,
-                        args,
-                    )
-                    results[group].update(
-                        ((typ, state_key), event_id)
-                        for typ, state_key, event_id in txn
-                        if (typ, state_key) not in results[group]
-                    )
-
-                    # If the number of entries in the (type,state_key)->event_id dict
-                    # matches the number of (type,state_keys) types we were searching
-                    # for, then we must have found them all, so no need to go walk
-                    # further down the tree... UNLESS our types filter contained
-                    # wildcards (i.e. Nones) in which case we have to do an exhaustive
-                    # search
-                    if (
-                        max_entries_returned is not None
-                        and len(results[group]) == max_entries_returned
-                    ):
-                        break
-
-                    next_group = self.db.simple_select_one_onecol_txn(
-                        txn,
-                        table="state_group_edges",
-                        keyvalues={"state_group": next_group},
-                        retcol="prev_state_group",
-                        allow_none=True,
-                    )
-
-        return results
-
-
 # this inherits from EventsWorkerStore because it calls self.get_events
-class StateGroupWorkerStore(
-    EventsWorkerStore, StateGroupBackgroundUpdateStore, SQLBaseStore
-):
+class StateGroupWorkerStore(EventsWorkerStore, SQLBaseStore):
     """The parts of StateGroupStore that can be called from workers.
     """
 
-    STATE_GROUP_DEDUPLICATION_UPDATE_NAME = "state_group_state_deduplication"
-    STATE_GROUP_INDEX_UPDATE_NAME = "state_group_state_type_index"
-    CURRENT_STATE_INDEX_UPDATE_NAME = "current_state_members_idx"
-
     def __init__(self, database: Database, db_conn, hs):
         super(StateGroupWorkerStore, self).__init__(database, db_conn, hs)
 
-        # Originally the state store used a single DictionaryCache to cache the
-        # event IDs for the state types in a given state group to avoid hammering
-        # on the state_group* tables.
-        #
-        # The point of using a DictionaryCache is that it can cache a subset
-        # of the state events for a given state group (i.e. a subset of the keys for a
-        # given dict which is an entry in the cache for a given state group ID).
-        #
-        # However, this poses problems when performing complicated queries
-        # on the store - for instance: "give me all the state for this group, but
-        # limit members to this subset of users", as DictionaryCache's API isn't
-        # rich enough to say "please cache any of these fields, apart from this subset".
-        # This is problematic when lazy loading members, which requires this behaviour,
-        # as without it the cache has no choice but to speculatively load all
-        # state events for the group, which negates the efficiency being sought.
-        #
-        # Rather than overcomplicating DictionaryCache's API, we instead split the
-        # state_group_cache into two halves - one for tracking non-member events,
-        # and the other for tracking member_events.  This means that lazy loading
-        # queries can be made in a cache-friendly manner by querying both caches
-        # separately and then merging the result.  So for the example above, you
-        # would query the members cache for a specific subset of state keys
-        # (which DictionaryCache will handle efficiently and fine) and the non-members
-        # cache for all state (which DictionaryCache will similarly handle fine)
-        # and then just merge the results together.
-        #
-        # We size the non-members cache to be smaller than the members cache as the
-        # vast majority of state in Matrix (today) is member events.
-
-        self._state_group_cache = DictionaryCache(
-            "*stateGroupCache*",
-            # TODO: this hasn't been tuned yet
-            50000 * get_cache_factor_for("stateGroupCache"),
-        )
-        self._state_group_members_cache = DictionaryCache(
-            "*stateGroupMembersCache*",
-            500000 * get_cache_factor_for("stateGroupMembersCache"),
-        )
-
     @defer.inlineCallbacks
     def get_room_version(self, room_id):
         """Get the room_version of a given room
@@ -278,7 +82,7 @@ class StateGroupWorkerStore(
 
     @defer.inlineCallbacks
     def get_room_predecessor(self, room_id):
-        """Get the predecessor room of an upgraded room if one exists.
+        """Get the predecessor of an upgraded room if it exists.
         Otherwise return None.
 
         Args:
@@ -291,14 +95,22 @@ class StateGroupWorkerStore(
                     * room_id (str): The room ID of the predecessor room
                     * event_id (str): The ID of the tombstone event in the predecessor room
 
+                None if a predecessor key is not found, or is not a dictionary.
+
         Raises:
-            NotFoundError if the room is unknown
+            NotFoundError if the given room is unknown
         """
         # Retrieve the room's create event
         create_event = yield self.get_create_event_for_room(room_id)
 
-        # Return predecessor if present
-        return create_event.content.get("predecessor", None)
+        # Retrieve the predecessor key of the create event
+        predecessor = create_event.content.get("predecessor", None)
+
+        # Ensure the key is a dictionary
+        if not isinstance(predecessor, dict):
+            return None
+
+        return predecessor
 
     @defer.inlineCallbacks
     def get_create_event_for_room(self, room_id):
@@ -318,7 +130,7 @@ class StateGroupWorkerStore(
 
         # If we can't find the create event, assume we've hit a dead end
         if not create_id:
-            raise NotFoundError("Unknown room %s" % (room_id))
+            raise NotFoundError("Unknown room %s" % (room_id,))
 
         # Retrieve the room's create event and return
         create_event = yield self.get_event(create_id)
@@ -423,229 +235,6 @@ class StateGroupWorkerStore(
 
         return event.content.get("canonical_alias")
 
-    @cached(max_entries=10000, iterable=True)
-    def get_state_group_delta(self, state_group):
-        """Given a state group try to return a previous group and a delta between
-        the old and the new.
-
-        Returns:
-            (prev_group, delta_ids), where both may be None.
-        """
-
-        def _get_state_group_delta_txn(txn):
-            prev_group = self.db.simple_select_one_onecol_txn(
-                txn,
-                table="state_group_edges",
-                keyvalues={"state_group": state_group},
-                retcol="prev_state_group",
-                allow_none=True,
-            )
-
-            if not prev_group:
-                return _GetStateGroupDelta(None, None)
-
-            delta_ids = self.db.simple_select_list_txn(
-                txn,
-                table="state_groups_state",
-                keyvalues={"state_group": state_group},
-                retcols=("type", "state_key", "event_id"),
-            )
-
-            return _GetStateGroupDelta(
-                prev_group,
-                {(row["type"], row["state_key"]): row["event_id"] for row in delta_ids},
-            )
-
-        return self.db.runInteraction(
-            "get_state_group_delta", _get_state_group_delta_txn
-        )
-
-    @defer.inlineCallbacks
-    def get_state_groups_ids(self, _room_id, event_ids):
-        """Get the event IDs of all the state for the state groups for the given events
-
-        Args:
-            _room_id (str): id of the room for these events
-            event_ids (iterable[str]): ids of the events
-
-        Returns:
-            Deferred[dict[int, dict[tuple[str, str], str]]]:
-                dict of state_group_id -> (dict of (type, state_key) -> event id)
-        """
-        if not event_ids:
-            return {}
-
-        event_to_groups = yield self._get_state_group_for_events(event_ids)
-
-        groups = set(itervalues(event_to_groups))
-        group_to_state = yield self._get_state_for_groups(groups)
-
-        return group_to_state
-
-    @defer.inlineCallbacks
-    def get_state_ids_for_group(self, state_group):
-        """Get the event IDs of all the state in the given state group
-
-        Args:
-            state_group (int)
-
-        Returns:
-            Deferred[dict]: Resolves to a map of (type, state_key) -> event_id
-        """
-        group_to_state = yield self._get_state_for_groups((state_group,))
-
-        return group_to_state[state_group]
-
-    @defer.inlineCallbacks
-    def get_state_groups(self, room_id, event_ids):
-        """ Get the state groups for the given list of event_ids
-
-        Returns:
-            Deferred[dict[int, list[EventBase]]]:
-                dict of state_group_id -> list of state events.
-        """
-        if not event_ids:
-            return {}
-
-        group_to_ids = yield self.get_state_groups_ids(room_id, event_ids)
-
-        state_event_map = yield self.get_events(
-            [
-                ev_id
-                for group_ids in itervalues(group_to_ids)
-                for ev_id in itervalues(group_ids)
-            ],
-            get_prev_content=False,
-        )
-
-        return {
-            group: [
-                state_event_map[v]
-                for v in itervalues(event_id_map)
-                if v in state_event_map
-            ]
-            for group, event_id_map in iteritems(group_to_ids)
-        }
-
-    @defer.inlineCallbacks
-    def _get_state_groups_from_groups(self, groups, state_filter):
-        """Returns the state groups for a given set of groups, filtering on
-        types of state events.
-
-        Args:
-            groups(list[int]): list of state group IDs to query
-            state_filter (StateFilter): The state filter used to fetch state
-                from the database.
-        Returns:
-            Deferred[dict[int, dict[tuple[str, str], str]]]:
-                dict of state_group_id -> (dict of (type, state_key) -> event id)
-        """
-        results = {}
-
-        chunks = [groups[i : i + 100] for i in range(0, len(groups), 100)]
-        for chunk in chunks:
-            res = yield self.db.runInteraction(
-                "_get_state_groups_from_groups",
-                self._get_state_groups_from_groups_txn,
-                chunk,
-                state_filter,
-            )
-            results.update(res)
-
-        return results
-
-    @defer.inlineCallbacks
-    def get_state_for_events(self, event_ids, state_filter=StateFilter.all()):
-        """Given a list of event_ids and type tuples, return a list of state
-        dicts for each event.
-
-        Args:
-            event_ids (list[string])
-            state_filter (StateFilter): The state filter used to fetch state
-                from the database.
-
-        Returns:
-            deferred: A dict of (event_id) -> (type, state_key) -> [state_events]
-        """
-        event_to_groups = yield self._get_state_group_for_events(event_ids)
-
-        groups = set(itervalues(event_to_groups))
-        group_to_state = yield self._get_state_for_groups(groups, state_filter)
-
-        state_event_map = yield self.get_events(
-            [ev_id for sd in itervalues(group_to_state) for ev_id in itervalues(sd)],
-            get_prev_content=False,
-        )
-
-        event_to_state = {
-            event_id: {
-                k: state_event_map[v]
-                for k, v in iteritems(group_to_state[group])
-                if v in state_event_map
-            }
-            for event_id, group in iteritems(event_to_groups)
-        }
-
-        return {event: event_to_state[event] for event in event_ids}
-
-    @defer.inlineCallbacks
-    def get_state_ids_for_events(self, event_ids, state_filter=StateFilter.all()):
-        """
-        Get the state dicts corresponding to a list of events, containing the event_ids
-        of the state events (as opposed to the events themselves)
-
-        Args:
-            event_ids(list(str)): events whose state should be returned
-            state_filter (StateFilter): The state filter used to fetch state
-                from the database.
-
-        Returns:
-            A deferred dict from event_id -> (type, state_key) -> event_id
-        """
-        event_to_groups = yield self._get_state_group_for_events(event_ids)
-
-        groups = set(itervalues(event_to_groups))
-        group_to_state = yield self._get_state_for_groups(groups, state_filter)
-
-        event_to_state = {
-            event_id: group_to_state[group]
-            for event_id, group in iteritems(event_to_groups)
-        }
-
-        return {event: event_to_state[event] for event in event_ids}
-
-    @defer.inlineCallbacks
-    def get_state_for_event(self, event_id, state_filter=StateFilter.all()):
-        """
-        Get the state dict corresponding to a particular event
-
-        Args:
-            event_id(str): event whose state should be returned
-            state_filter (StateFilter): The state filter used to fetch state
-                from the database.
-
-        Returns:
-            A deferred dict from (type, state_key) -> state_event
-        """
-        state_map = yield self.get_state_for_events([event_id], state_filter)
-        return state_map[event_id]
-
-    @defer.inlineCallbacks
-    def get_state_ids_for_event(self, event_id, state_filter=StateFilter.all()):
-        """
-        Get the state dict corresponding to a particular event
-
-        Args:
-            event_id(str): event whose state should be returned
-            state_filter (StateFilter): The state filter used to fetch state
-                from the database.
-
-        Returns:
-            A deferred dict from (type, state_key) -> state_event
-        """
-        state_map = yield self.get_state_ids_for_events([event_id], state_filter)
-        return state_map[event_id]
-
     @cached(max_entries=50000)
     def _get_state_group_for_event(self, event_id):
         return self.db.simple_select_one_onecol(
@@ -676,329 +265,6 @@ class StateGroupWorkerStore(
 
         return {row["event_id"]: row["state_group"] for row in rows}
 
-    def _get_state_for_group_using_cache(self, cache, group, state_filter):
-        """Checks if group is in cache. See `_get_state_for_groups`
-
-        Args:
-            cache(DictionaryCache): the state group cache to use
-            group(int): The state group to lookup
-            state_filter (StateFilter): The state filter used to fetch state
-                from the database.
-
-        Returns 2-tuple (`state_dict`, `got_all`).
-        `got_all` is a bool indicating if we successfully retrieved all
-        requests state from the cache, if False we need to query the DB for the
-        missing state.
-        """
-        is_all, known_absent, state_dict_ids = cache.get(group)
-
-        if is_all or state_filter.is_full():
-            # Either we have everything or want everything, either way
-            # `is_all` tells us whether we've gotten everything.
-            return state_filter.filter_state(state_dict_ids), is_all
-
-        # tracks whether any of our requested types are missing from the cache
-        missing_types = False
-
-        if state_filter.has_wildcards():
-            # We don't know if we fetched all the state keys for the types in
-            # the filter that are wildcards, so we have to assume that we may
-            # have missed some.
-            missing_types = True
-        else:
-            # There aren't any wild cards, so `concrete_types()` returns the
-            # complete list of event types we're wanting.
-            for key in state_filter.concrete_types():
-                if key not in state_dict_ids and key not in known_absent:
-                    missing_types = True
-                    break
-
-        return state_filter.filter_state(state_dict_ids), not missing_types
-
-    @defer.inlineCallbacks
-    def _get_state_for_groups(self, groups, state_filter=StateFilter.all()):
-        """Gets the state at each of a list of state groups, optionally
-        filtering by type/state_key
-
-        Args:
-            groups (iterable[int]): list of state groups for which we want
-                to get the state.
-            state_filter (StateFilter): The state filter used to fetch state
-                from the database.
-        Returns:
-            Deferred[dict[int, dict[tuple[str, str], str]]]:
-                dict of state_group_id -> (dict of (type, state_key) -> event id)
-        """
-
-        member_filter, non_member_filter = state_filter.get_member_split()
-
-        # Now we look them up in the member and non-member caches
-        (
-            non_member_state,
-            incomplete_groups_nm,
-        ) = yield self._get_state_for_groups_using_cache(
-            groups, self._state_group_cache, state_filter=non_member_filter
-        )
-
-        (
-            member_state,
-            incomplete_groups_m,
-        ) = yield self._get_state_for_groups_using_cache(
-            groups, self._state_group_members_cache, state_filter=member_filter
-        )
-
-        state = dict(non_member_state)
-        for group in groups:
-            state[group].update(member_state[group])
-
-        # Now fetch any missing groups from the database
-
-        incomplete_groups = incomplete_groups_m | incomplete_groups_nm
-
-        if not incomplete_groups:
-            return state
-
-        cache_sequence_nm = self._state_group_cache.sequence
-        cache_sequence_m = self._state_group_members_cache.sequence
-
-        # Help the cache hit ratio by expanding the filter a bit
-        db_state_filter = state_filter.return_expanded()
-
-        group_to_state_dict = yield self._get_state_groups_from_groups(
-            list(incomplete_groups), state_filter=db_state_filter
-        )
-
-        # Now lets update the caches
-        self._insert_into_cache(
-            group_to_state_dict,
-            db_state_filter,
-            cache_seq_num_members=cache_sequence_m,
-            cache_seq_num_non_members=cache_sequence_nm,
-        )
-
-        # And finally update the result dict, by filtering out any extra
-        # stuff we pulled out of the database.
-        for group, group_state_dict in iteritems(group_to_state_dict):
-            # We just replace any existing entries, as we will have loaded
-            # everything we need from the database anyway.
-            state[group] = state_filter.filter_state(group_state_dict)
-
-        return state
-
-    def _get_state_for_groups_using_cache(self, groups, cache, state_filter):
-        """Gets the state at each of a list of state groups, optionally
-        filtering by type/state_key, querying from a specific cache.
-
-        Args:
-            groups (iterable[int]): list of state groups for which we want
-                to get the state.
-            cache (DictionaryCache): the cache of group ids to state dicts which
-                we will pass through - either the normal state cache or the specific
-                members state cache.
-            state_filter (StateFilter): The state filter used to fetch state
-                from the database.
-
-        Returns:
-            tuple[dict[int, dict[tuple[str, str], str]], set[int]]: Tuple of
-            dict of state_group_id -> (dict of (type, state_key) -> event id)
-            of entries in the cache, and the state group ids either missing
-            from the cache or incomplete.
-        """
-        results = {}
-        incomplete_groups = set()
-        for group in set(groups):
-            state_dict_ids, got_all = self._get_state_for_group_using_cache(
-                cache, group, state_filter
-            )
-            results[group] = state_dict_ids
-
-            if not got_all:
-                incomplete_groups.add(group)
-
-        return results, incomplete_groups
-
-    def _insert_into_cache(
-        self,
-        group_to_state_dict,
-        state_filter,
-        cache_seq_num_members,
-        cache_seq_num_non_members,
-    ):
-        """Inserts results from querying the database into the relevant cache.
-
-        Args:
-            group_to_state_dict (dict): The new entries pulled from database.
-                Map from state group to state dict
-            state_filter (StateFilter): The state filter used to fetch state
-                from the database.
-            cache_seq_num_members (int): Sequence number of member cache since
-                last lookup in cache
-            cache_seq_num_non_members (int): Sequence number of member cache since
-                last lookup in cache
-        """
-
-        # We need to work out which types we've fetched from the DB for the
-        # member vs non-member caches. This should be as accurate as possible,
-        # but can be an underestimate (e.g. when we have wild cards)
-
-        member_filter, non_member_filter = state_filter.get_member_split()
-        if member_filter.is_full():
-            # We fetched all member events
-            member_types = None
-        else:
-            # `concrete_types()` will only return a subset when there are wild
-            # cards in the filter, but that's fine.
-            member_types = member_filter.concrete_types()
-
-        if non_member_filter.is_full():
-            # We fetched all non member events
-            non_member_types = None
-        else:
-            non_member_types = non_member_filter.concrete_types()
-
-        for group, group_state_dict in iteritems(group_to_state_dict):
-            state_dict_members = {}
-            state_dict_non_members = {}
-
-            for k, v in iteritems(group_state_dict):
-                if k[0] == EventTypes.Member:
-                    state_dict_members[k] = v
-                else:
-                    state_dict_non_members[k] = v
-
-            self._state_group_members_cache.update(
-                cache_seq_num_members,
-                key=group,
-                value=state_dict_members,
-                fetched_keys=member_types,
-            )
-
-            self._state_group_cache.update(
-                cache_seq_num_non_members,
-                key=group,
-                value=state_dict_non_members,
-                fetched_keys=non_member_types,
-            )
-
-    def store_state_group(
-        self, event_id, room_id, prev_group, delta_ids, current_state_ids
-    ):
-        """Store a new set of state, returning a newly assigned state group.
-
-        Args:
-            event_id (str): The event ID for which the state was calculated
-            room_id (str)
-            prev_group (int|None): A previous state group for the room, optional.
-            delta_ids (dict|None): The delta between state at `prev_group` and
-                `current_state_ids`, if `prev_group` was given. Same format as
-                `current_state_ids`.
-            current_state_ids (dict): The state to store. Map of (type, state_key)
-                to event_id.
-
-        Returns:
-            Deferred[int]: The state group ID
-        """
-
-        def _store_state_group_txn(txn):
-            if current_state_ids is None:
-                # AFAIK, this can never happen
-                raise Exception("current_state_ids cannot be None")
-
-            state_group = self.database_engine.get_next_state_group_id(txn)
-
-            self.db.simple_insert_txn(
-                txn,
-                table="state_groups",
-                values={"id": state_group, "room_id": room_id, "event_id": event_id},
-            )
-
-            # We persist as a delta if we can, while also ensuring the chain
-            # of deltas isn't tooo long, as otherwise read performance degrades.
-            if prev_group:
-                is_in_db = self.db.simple_select_one_onecol_txn(
-                    txn,
-                    table="state_groups",
-                    keyvalues={"id": prev_group},
-                    retcol="id",
-                    allow_none=True,
-                )
-                if not is_in_db:
-                    raise Exception(
-                        "Trying to persist state with unpersisted prev_group: %r"
-                        % (prev_group,)
-                    )
-
-                potential_hops = self._count_state_group_hops_txn(txn, prev_group)
-            if prev_group and potential_hops < MAX_STATE_DELTA_HOPS:
-                self.db.simple_insert_txn(
-                    txn,
-                    table="state_group_edges",
-                    values={"state_group": state_group, "prev_state_group": prev_group},
-                )
-
-                self.db.simple_insert_many_txn(
-                    txn,
-                    table="state_groups_state",
-                    values=[
-                        {
-                            "state_group": state_group,
-                            "room_id": room_id,
-                            "type": key[0],
-                            "state_key": key[1],
-                            "event_id": state_id,
-                        }
-                        for key, state_id in iteritems(delta_ids)
-                    ],
-                )
-            else:
-                self.db.simple_insert_many_txn(
-                    txn,
-                    table="state_groups_state",
-                    values=[
-                        {
-                            "state_group": state_group,
-                            "room_id": room_id,
-                            "type": key[0],
-                            "state_key": key[1],
-                            "event_id": state_id,
-                        }
-                        for key, state_id in iteritems(current_state_ids)
-                    ],
-                )
-
-            # Prefill the state group caches with this group.
-            # It's fine to use the sequence like this as the state group map
-            # is immutable. (If the map wasn't immutable then this prefill could
-            # race with another update)
-
-            current_member_state_ids = {
-                s: ev
-                for (s, ev) in iteritems(current_state_ids)
-                if s[0] == EventTypes.Member
-            }
-            txn.call_after(
-                self._state_group_members_cache.update,
-                self._state_group_members_cache.sequence,
-                key=state_group,
-                value=dict(current_member_state_ids),
-            )
-
-            current_non_member_state_ids = {
-                s: ev
-                for (s, ev) in iteritems(current_state_ids)
-                if s[0] != EventTypes.Member
-            }
-            txn.call_after(
-                self._state_group_cache.update,
-                self._state_group_cache.sequence,
-                key=state_group,
-                value=dict(current_non_member_state_ids),
-            )
-
-            return state_group
-
-        return self.db.runInteraction("store_state_group", _store_state_group_txn)
-
     @defer.inlineCallbacks
     def get_referenced_state_groups(self, state_groups):
         """Check if the state groups are referenced by events.
@@ -1023,22 +289,14 @@ class StateGroupWorkerStore(
         return set(row["state_group"] for row in rows)
 
 
-class StateBackgroundUpdateStore(StateGroupBackgroundUpdateStore):
+class MainStateBackgroundUpdateStore(SQLBaseStore):
 
-    STATE_GROUP_DEDUPLICATION_UPDATE_NAME = "state_group_state_deduplication"
-    STATE_GROUP_INDEX_UPDATE_NAME = "state_group_state_type_index"
     CURRENT_STATE_INDEX_UPDATE_NAME = "current_state_members_idx"
     EVENT_STATE_GROUP_INDEX_UPDATE_NAME = "event_to_state_groups_sg_index"
 
     def __init__(self, database: Database, db_conn, hs):
-        super(StateBackgroundUpdateStore, self).__init__(database, db_conn, hs)
-        self.db.updates.register_background_update_handler(
-            self.STATE_GROUP_DEDUPLICATION_UPDATE_NAME,
-            self._background_deduplicate_state,
-        )
-        self.db.updates.register_background_update_handler(
-            self.STATE_GROUP_INDEX_UPDATE_NAME, self._background_index_state
-        )
+        super(MainStateBackgroundUpdateStore, self).__init__(database, db_conn, hs)
+
         self.db.updates.register_background_index_update(
             self.CURRENT_STATE_INDEX_UPDATE_NAME,
             index_name="current_state_events_member_index",
@@ -1053,181 +311,8 @@ class StateBackgroundUpdateStore(StateGroupBackgroundUpdateStore):
             columns=["state_group"],
         )
 
-    @defer.inlineCallbacks
-    def _background_deduplicate_state(self, progress, batch_size):
-        """This background update will slowly deduplicate state by reencoding
-        them as deltas.
-        """
-        last_state_group = progress.get("last_state_group", 0)
-        rows_inserted = progress.get("rows_inserted", 0)
-        max_group = progress.get("max_group", None)
-
-        BATCH_SIZE_SCALE_FACTOR = 100
 
-        batch_size = max(1, int(batch_size / BATCH_SIZE_SCALE_FACTOR))
-
-        if max_group is None:
-            rows = yield self.db.execute(
-                "_background_deduplicate_state",
-                None,
-                "SELECT coalesce(max(id), 0) FROM state_groups",
-            )
-            max_group = rows[0][0]
-
-        def reindex_txn(txn):
-            new_last_state_group = last_state_group
-            for count in range(batch_size):
-                txn.execute(
-                    "SELECT id, room_id FROM state_groups"
-                    " WHERE ? < id AND id <= ?"
-                    " ORDER BY id ASC"
-                    " LIMIT 1",
-                    (new_last_state_group, max_group),
-                )
-                row = txn.fetchone()
-                if row:
-                    state_group, room_id = row
-
-                if not row or not state_group:
-                    return True, count
-
-                txn.execute(
-                    "SELECT state_group FROM state_group_edges"
-                    " WHERE state_group = ?",
-                    (state_group,),
-                )
-
-                # If we reach a point where we've already started inserting
-                # edges we should stop.
-                if txn.fetchall():
-                    return True, count
-
-                txn.execute(
-                    "SELECT coalesce(max(id), 0) FROM state_groups"
-                    " WHERE id < ? AND room_id = ?",
-                    (state_group, room_id),
-                )
-                (prev_group,) = txn.fetchone()
-                new_last_state_group = state_group
-
-                if prev_group:
-                    potential_hops = self._count_state_group_hops_txn(txn, prev_group)
-                    if potential_hops >= MAX_STATE_DELTA_HOPS:
-                        # We want to ensure chains are at most this long,#
-                        # otherwise read performance degrades.
-                        continue
-
-                    prev_state = self._get_state_groups_from_groups_txn(
-                        txn, [prev_group]
-                    )
-                    prev_state = prev_state[prev_group]
-
-                    curr_state = self._get_state_groups_from_groups_txn(
-                        txn, [state_group]
-                    )
-                    curr_state = curr_state[state_group]
-
-                    if not set(prev_state.keys()) - set(curr_state.keys()):
-                        # We can only do a delta if the current has a strict super set
-                        # of keys
-
-                        delta_state = {
-                            key: value
-                            for key, value in iteritems(curr_state)
-                            if prev_state.get(key, None) != value
-                        }
-
-                        self.db.simple_delete_txn(
-                            txn,
-                            table="state_group_edges",
-                            keyvalues={"state_group": state_group},
-                        )
-
-                        self.db.simple_insert_txn(
-                            txn,
-                            table="state_group_edges",
-                            values={
-                                "state_group": state_group,
-                                "prev_state_group": prev_group,
-                            },
-                        )
-
-                        self.db.simple_delete_txn(
-                            txn,
-                            table="state_groups_state",
-                            keyvalues={"state_group": state_group},
-                        )
-
-                        self.db.simple_insert_many_txn(
-                            txn,
-                            table="state_groups_state",
-                            values=[
-                                {
-                                    "state_group": state_group,
-                                    "room_id": room_id,
-                                    "type": key[0],
-                                    "state_key": key[1],
-                                    "event_id": state_id,
-                                }
-                                for key, state_id in iteritems(delta_state)
-                            ],
-                        )
-
-            progress = {
-                "last_state_group": state_group,
-                "rows_inserted": rows_inserted + batch_size,
-                "max_group": max_group,
-            }
-
-            self.db.updates._background_update_progress_txn(
-                txn, self.STATE_GROUP_DEDUPLICATION_UPDATE_NAME, progress
-            )
-
-            return False, batch_size
-
-        finished, result = yield self.db.runInteraction(
-            self.STATE_GROUP_DEDUPLICATION_UPDATE_NAME, reindex_txn
-        )
-
-        if finished:
-            yield self.db.updates._end_background_update(
-                self.STATE_GROUP_DEDUPLICATION_UPDATE_NAME
-            )
-
-        return result * BATCH_SIZE_SCALE_FACTOR
-
-    @defer.inlineCallbacks
-    def _background_index_state(self, progress, batch_size):
-        def reindex_txn(conn):
-            conn.rollback()
-            if isinstance(self.database_engine, PostgresEngine):
-                # postgres insists on autocommit for the index
-                conn.set_session(autocommit=True)
-                try:
-                    txn = conn.cursor()
-                    txn.execute(
-                        "CREATE INDEX CONCURRENTLY state_groups_state_type_idx"
-                        " ON state_groups_state(state_group, type, state_key)"
-                    )
-                    txn.execute("DROP INDEX IF EXISTS state_groups_state_id")
-                finally:
-                    conn.set_session(autocommit=False)
-            else:
-                txn = conn.cursor()
-                txn.execute(
-                    "CREATE INDEX state_groups_state_type_idx"
-                    " ON state_groups_state(state_group, type, state_key)"
-                )
-                txn.execute("DROP INDEX IF EXISTS state_groups_state_id")
-
-        yield self.db.runWithConnection(reindex_txn)
-
-        yield self.db.updates._end_background_update(self.STATE_GROUP_INDEX_UPDATE_NAME)
-
-        return 1
-
-
-class StateStore(StateGroupWorkerStore, StateBackgroundUpdateStore):
+class StateStore(StateGroupWorkerStore, MainStateBackgroundUpdateStore):
     """ Keeps track of the state at a given event.
 
     This is done by the concept of `state groups`. Every event is a assigned
diff --git a/synapse/storage/data_stores/state/__init__.py b/synapse/storage/data_stores/state/__init__.py
new file mode 100644
index 0000000000..86e09f6229
--- /dev/null
+++ b/synapse/storage/data_stores/state/__init__.py
@@ -0,0 +1,16 @@
+# -*- coding: utf-8 -*-
+# Copyright 2019 The Matrix.org Foundation C.I.C.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from synapse.storage.data_stores.state.store import StateGroupDataStore  # noqa: F401
diff --git a/synapse/storage/data_stores/state/bg_updates.py b/synapse/storage/data_stores/state/bg_updates.py
new file mode 100644
index 0000000000..e8edaf9f7b
--- /dev/null
+++ b/synapse/storage/data_stores/state/bg_updates.py
@@ -0,0 +1,374 @@
+# -*- coding: utf-8 -*-
+# Copyright 2014-2016 OpenMarket Ltd
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import logging
+
+from six import iteritems
+
+from twisted.internet import defer
+
+from synapse.storage._base import SQLBaseStore
+from synapse.storage.database import Database
+from synapse.storage.engines import PostgresEngine
+from synapse.storage.state import StateFilter
+
+logger = logging.getLogger(__name__)
+
+
+MAX_STATE_DELTA_HOPS = 100
+
+
+class StateGroupBackgroundUpdateStore(SQLBaseStore):
+    """Defines functions related to state groups needed to run the state backgroud
+    updates.
+    """
+
+    def _count_state_group_hops_txn(self, txn, state_group):
+        """Given a state group, count how many hops there are in the tree.
+
+        This is used to ensure the delta chains don't get too long.
+        """
+        if isinstance(self.database_engine, PostgresEngine):
+            sql = """
+                WITH RECURSIVE state(state_group) AS (
+                    VALUES(?::bigint)
+                    UNION ALL
+                    SELECT prev_state_group FROM state_group_edges e, state s
+                    WHERE s.state_group = e.state_group
+                )
+                SELECT count(*) FROM state;
+            """
+
+            txn.execute(sql, (state_group,))
+            row = txn.fetchone()
+            if row and row[0]:
+                return row[0]
+            else:
+                return 0
+        else:
+            # We don't use WITH RECURSIVE on sqlite3 as there are distributions
+            # that ship with an sqlite3 version that doesn't support it (e.g. wheezy)
+            next_group = state_group
+            count = 0
+
+            while next_group:
+                next_group = self.db.simple_select_one_onecol_txn(
+                    txn,
+                    table="state_group_edges",
+                    keyvalues={"state_group": next_group},
+                    retcol="prev_state_group",
+                    allow_none=True,
+                )
+                if next_group:
+                    count += 1
+
+            return count
+
+    def _get_state_groups_from_groups_txn(
+        self, txn, groups, state_filter=StateFilter.all()
+    ):
+        results = {group: {} for group in groups}
+
+        where_clause, where_args = state_filter.make_sql_filter_clause()
+
+        # Unless the filter clause is empty, we're going to append it after an
+        # existing where clause
+        if where_clause:
+            where_clause = " AND (%s)" % (where_clause,)
+
+        if isinstance(self.database_engine, PostgresEngine):
+            # Temporarily disable sequential scans in this transaction. This is
+            # a temporary hack until we can add the right indices in
+            txn.execute("SET LOCAL enable_seqscan=off")
+
+            # The below query walks the state_group tree so that the "state"
+            # table includes all state_groups in the tree. It then joins
+            # against `state_groups_state` to fetch the latest state.
+            # It assumes that previous state groups are always numerically
+            # lesser.
+            # The PARTITION is used to get the event_id in the greatest state
+            # group for the given type, state_key.
+            # This may return multiple rows per (type, state_key), but last_value
+            # should be the same.
+            sql = """
+                WITH RECURSIVE state(state_group) AS (
+                    VALUES(?::bigint)
+                    UNION ALL
+                    SELECT prev_state_group FROM state_group_edges e, state s
+                    WHERE s.state_group = e.state_group
+                )
+                SELECT DISTINCT type, state_key, last_value(event_id) OVER (
+                    PARTITION BY type, state_key ORDER BY state_group ASC
+                    ROWS BETWEEN UNBOUNDED PRECEDING AND UNBOUNDED FOLLOWING
+                ) AS event_id FROM state_groups_state
+                WHERE state_group IN (
+                    SELECT state_group FROM state
+                )
+            """
+
+            for group in groups:
+                args = [group]
+                args.extend(where_args)
+
+                txn.execute(sql + where_clause, args)
+                for row in txn:
+                    typ, state_key, event_id = row
+                    key = (typ, state_key)
+                    results[group][key] = event_id
+        else:
+            max_entries_returned = state_filter.max_entries_returned()
+
+            # We don't use WITH RECURSIVE on sqlite3 as there are distributions
+            # that ship with an sqlite3 version that doesn't support it (e.g. wheezy)
+            for group in groups:
+                next_group = group
+
+                while next_group:
+                    # We did this before by getting the list of group ids, and
+                    # then passing that list to sqlite to get latest event for
+                    # each (type, state_key). However, that was terribly slow
+                    # without the right indices (which we can't add until
+                    # after we finish deduping state, which requires this func)
+                    args = [next_group]
+                    args.extend(where_args)
+
+                    txn.execute(
+                        "SELECT type, state_key, event_id FROM state_groups_state"
+                        " WHERE state_group = ? " + where_clause,
+                        args,
+                    )
+                    results[group].update(
+                        ((typ, state_key), event_id)
+                        for typ, state_key, event_id in txn
+                        if (typ, state_key) not in results[group]
+                    )
+
+                    # If the number of entries in the (type,state_key)->event_id dict
+                    # matches the number of (type,state_keys) types we were searching
+                    # for, then we must have found them all, so no need to go walk
+                    # further down the tree... UNLESS our types filter contained
+                    # wildcards (i.e. Nones) in which case we have to do an exhaustive
+                    # search
+                    if (
+                        max_entries_returned is not None
+                        and len(results[group]) == max_entries_returned
+                    ):
+                        break
+
+                    next_group = self.db.simple_select_one_onecol_txn(
+                        txn,
+                        table="state_group_edges",
+                        keyvalues={"state_group": next_group},
+                        retcol="prev_state_group",
+                        allow_none=True,
+                    )
+
+        return results
+
+
+class StateBackgroundUpdateStore(StateGroupBackgroundUpdateStore):
+
+    STATE_GROUP_DEDUPLICATION_UPDATE_NAME = "state_group_state_deduplication"
+    STATE_GROUP_INDEX_UPDATE_NAME = "state_group_state_type_index"
+    STATE_GROUPS_ROOM_INDEX_UPDATE_NAME = "state_groups_room_id_idx"
+
+    def __init__(self, database: Database, db_conn, hs):
+        super(StateBackgroundUpdateStore, self).__init__(database, db_conn, hs)
+        self.db.updates.register_background_update_handler(
+            self.STATE_GROUP_DEDUPLICATION_UPDATE_NAME,
+            self._background_deduplicate_state,
+        )
+        self.db.updates.register_background_update_handler(
+            self.STATE_GROUP_INDEX_UPDATE_NAME, self._background_index_state
+        )
+        self.db.updates.register_background_index_update(
+            self.STATE_GROUPS_ROOM_INDEX_UPDATE_NAME,
+            index_name="state_groups_room_id_idx",
+            table="state_groups",
+            columns=["room_id"],
+        )
+
+    @defer.inlineCallbacks
+    def _background_deduplicate_state(self, progress, batch_size):
+        """This background update will slowly deduplicate state by reencoding
+        them as deltas.
+        """
+        last_state_group = progress.get("last_state_group", 0)
+        rows_inserted = progress.get("rows_inserted", 0)
+        max_group = progress.get("max_group", None)
+
+        BATCH_SIZE_SCALE_FACTOR = 100
+
+        batch_size = max(1, int(batch_size / BATCH_SIZE_SCALE_FACTOR))
+
+        if max_group is None:
+            rows = yield self.db.execute(
+                "_background_deduplicate_state",
+                None,
+                "SELECT coalesce(max(id), 0) FROM state_groups",
+            )
+            max_group = rows[0][0]
+
+        def reindex_txn(txn):
+            new_last_state_group = last_state_group
+            for count in range(batch_size):
+                txn.execute(
+                    "SELECT id, room_id FROM state_groups"
+                    " WHERE ? < id AND id <= ?"
+                    " ORDER BY id ASC"
+                    " LIMIT 1",
+                    (new_last_state_group, max_group),
+                )
+                row = txn.fetchone()
+                if row:
+                    state_group, room_id = row
+
+                if not row or not state_group:
+                    return True, count
+
+                txn.execute(
+                    "SELECT state_group FROM state_group_edges"
+                    " WHERE state_group = ?",
+                    (state_group,),
+                )
+
+                # If we reach a point where we've already started inserting
+                # edges we should stop.
+                if txn.fetchall():
+                    return True, count
+
+                txn.execute(
+                    "SELECT coalesce(max(id), 0) FROM state_groups"
+                    " WHERE id < ? AND room_id = ?",
+                    (state_group, room_id),
+                )
+                (prev_group,) = txn.fetchone()
+                new_last_state_group = state_group
+
+                if prev_group:
+                    potential_hops = self._count_state_group_hops_txn(txn, prev_group)
+                    if potential_hops >= MAX_STATE_DELTA_HOPS:
+                        # We want to ensure chains are at most this long,#
+                        # otherwise read performance degrades.
+                        continue
+
+                    prev_state = self._get_state_groups_from_groups_txn(
+                        txn, [prev_group]
+                    )
+                    prev_state = prev_state[prev_group]
+
+                    curr_state = self._get_state_groups_from_groups_txn(
+                        txn, [state_group]
+                    )
+                    curr_state = curr_state[state_group]
+
+                    if not set(prev_state.keys()) - set(curr_state.keys()):
+                        # We can only do a delta if the current has a strict super set
+                        # of keys
+
+                        delta_state = {
+                            key: value
+                            for key, value in iteritems(curr_state)
+                            if prev_state.get(key, None) != value
+                        }
+
+                        self.db.simple_delete_txn(
+                            txn,
+                            table="state_group_edges",
+                            keyvalues={"state_group": state_group},
+                        )
+
+                        self.db.simple_insert_txn(
+                            txn,
+                            table="state_group_edges",
+                            values={
+                                "state_group": state_group,
+                                "prev_state_group": prev_group,
+                            },
+                        )
+
+                        self.db.simple_delete_txn(
+                            txn,
+                            table="state_groups_state",
+                            keyvalues={"state_group": state_group},
+                        )
+
+                        self.db.simple_insert_many_txn(
+                            txn,
+                            table="state_groups_state",
+                            values=[
+                                {
+                                    "state_group": state_group,
+                                    "room_id": room_id,
+                                    "type": key[0],
+                                    "state_key": key[1],
+                                    "event_id": state_id,
+                                }
+                                for key, state_id in iteritems(delta_state)
+                            ],
+                        )
+
+            progress = {
+                "last_state_group": state_group,
+                "rows_inserted": rows_inserted + batch_size,
+                "max_group": max_group,
+            }
+
+            self.db.updates._background_update_progress_txn(
+                txn, self.STATE_GROUP_DEDUPLICATION_UPDATE_NAME, progress
+            )
+
+            return False, batch_size
+
+        finished, result = yield self.db.runInteraction(
+            self.STATE_GROUP_DEDUPLICATION_UPDATE_NAME, reindex_txn
+        )
+
+        if finished:
+            yield self.db.updates._end_background_update(
+                self.STATE_GROUP_DEDUPLICATION_UPDATE_NAME
+            )
+
+        return result * BATCH_SIZE_SCALE_FACTOR
+
+    @defer.inlineCallbacks
+    def _background_index_state(self, progress, batch_size):
+        def reindex_txn(conn):
+            conn.rollback()
+            if isinstance(self.database_engine, PostgresEngine):
+                # postgres insists on autocommit for the index
+                conn.set_session(autocommit=True)
+                try:
+                    txn = conn.cursor()
+                    txn.execute(
+                        "CREATE INDEX CONCURRENTLY state_groups_state_type_idx"
+                        " ON state_groups_state(state_group, type, state_key)"
+                    )
+                    txn.execute("DROP INDEX IF EXISTS state_groups_state_id")
+                finally:
+                    conn.set_session(autocommit=False)
+            else:
+                txn = conn.cursor()
+                txn.execute(
+                    "CREATE INDEX state_groups_state_type_idx"
+                    " ON state_groups_state(state_group, type, state_key)"
+                )
+                txn.execute("DROP INDEX IF EXISTS state_groups_state_id")
+
+        yield self.db.runWithConnection(reindex_txn)
+
+        yield self.db.updates._end_background_update(self.STATE_GROUP_INDEX_UPDATE_NAME)
+
+        return 1
diff --git a/synapse/storage/data_stores/main/schema/delta/23/drop_state_index.sql b/synapse/storage/data_stores/state/schema/delta/23/drop_state_index.sql
index ae09fa0065..ae09fa0065 100644
--- a/synapse/storage/data_stores/main/schema/delta/23/drop_state_index.sql
+++ b/synapse/storage/data_stores/state/schema/delta/23/drop_state_index.sql
diff --git a/synapse/storage/data_stores/main/schema/delta/30/state_stream.sql b/synapse/storage/data_stores/state/schema/delta/30/state_stream.sql
index e85699e82e..e85699e82e 100644
--- a/synapse/storage/data_stores/main/schema/delta/30/state_stream.sql
+++ b/synapse/storage/data_stores/state/schema/delta/30/state_stream.sql
diff --git a/synapse/storage/data_stores/state/schema/delta/32/remove_state_indices.sql b/synapse/storage/data_stores/state/schema/delta/32/remove_state_indices.sql
new file mode 100644
index 0000000000..1450313bfa
--- /dev/null
+++ b/synapse/storage/data_stores/state/schema/delta/32/remove_state_indices.sql
@@ -0,0 +1,19 @@
+/* Copyright 2016 OpenMarket Ltd
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+
+-- The following indices are redundant, other indices are equivalent or
+-- supersets
+DROP INDEX IF EXISTS state_groups_id; -- Duplicate of PRIMARY KEY
diff --git a/synapse/storage/data_stores/main/schema/delta/35/add_state_index.sql b/synapse/storage/data_stores/state/schema/delta/35/add_state_index.sql
index 33980d02f0..33980d02f0 100644
--- a/synapse/storage/data_stores/main/schema/delta/35/add_state_index.sql
+++ b/synapse/storage/data_stores/state/schema/delta/35/add_state_index.sql
diff --git a/synapse/storage/data_stores/main/schema/delta/35/state.sql b/synapse/storage/data_stores/state/schema/delta/35/state.sql
index 0f1fa68a89..0f1fa68a89 100644
--- a/synapse/storage/data_stores/main/schema/delta/35/state.sql
+++ b/synapse/storage/data_stores/state/schema/delta/35/state.sql
diff --git a/synapse/storage/data_stores/main/schema/delta/35/state_dedupe.sql b/synapse/storage/data_stores/state/schema/delta/35/state_dedupe.sql
index 97e5067ef4..97e5067ef4 100644
--- a/synapse/storage/data_stores/main/schema/delta/35/state_dedupe.sql
+++ b/synapse/storage/data_stores/state/schema/delta/35/state_dedupe.sql
diff --git a/synapse/storage/data_stores/main/schema/delta/47/state_group_seq.py b/synapse/storage/data_stores/state/schema/delta/47/state_group_seq.py
index 9fd1ccf6f7..9fd1ccf6f7 100644
--- a/synapse/storage/data_stores/main/schema/delta/47/state_group_seq.py
+++ b/synapse/storage/data_stores/state/schema/delta/47/state_group_seq.py
diff --git a/synapse/storage/data_stores/state/schema/delta/56/state_group_room_idx.sql b/synapse/storage/data_stores/state/schema/delta/56/state_group_room_idx.sql
new file mode 100644
index 0000000000..7916ef18b2
--- /dev/null
+++ b/synapse/storage/data_stores/state/schema/delta/56/state_group_room_idx.sql
@@ -0,0 +1,17 @@
+/* Copyright 2019 The Matrix.org Foundation C.I.C.
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+INSERT INTO background_updates (update_name, progress_json) VALUES
+    ('state_groups_room_id_idx', '{}');
diff --git a/synapse/storage/data_stores/state/schema/full_schemas/54/full.sql b/synapse/storage/data_stores/state/schema/full_schemas/54/full.sql
new file mode 100644
index 0000000000..35f97d6b3d
--- /dev/null
+++ b/synapse/storage/data_stores/state/schema/full_schemas/54/full.sql
@@ -0,0 +1,37 @@
+/* Copyright 2019 The Matrix.org Foundation C.I.C
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+CREATE TABLE state_groups (
+    id BIGINT PRIMARY KEY,
+    room_id TEXT NOT NULL,
+    event_id TEXT NOT NULL
+);
+
+CREATE TABLE state_groups_state (
+    state_group BIGINT NOT NULL,
+    room_id TEXT NOT NULL,
+    type TEXT NOT NULL,
+    state_key TEXT NOT NULL,
+    event_id TEXT NOT NULL
+);
+
+CREATE TABLE state_group_edges (
+    state_group BIGINT NOT NULL,
+    prev_state_group BIGINT NOT NULL
+);
+
+CREATE INDEX state_group_edges_idx ON state_group_edges (state_group);
+CREATE INDEX state_group_edges_prev_idx ON state_group_edges (prev_state_group);
+CREATE INDEX state_groups_state_type_idx ON state_groups_state (state_group, type, state_key);
diff --git a/synapse/storage/data_stores/state/schema/full_schemas/54/sequence.sql.postgres b/synapse/storage/data_stores/state/schema/full_schemas/54/sequence.sql.postgres
new file mode 100644
index 0000000000..fcd926c9fb
--- /dev/null
+++ b/synapse/storage/data_stores/state/schema/full_schemas/54/sequence.sql.postgres
@@ -0,0 +1,21 @@
+/* Copyright 2019 The Matrix.org Foundation C.I.C
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+CREATE SEQUENCE state_group_id_seq
+    START WITH 1
+    INCREMENT BY 1
+    NO MINVALUE
+    NO MAXVALUE
+    CACHE 1;
diff --git a/synapse/storage/data_stores/state/store.py b/synapse/storage/data_stores/state/store.py
new file mode 100644
index 0000000000..d53695f238
--- /dev/null
+++ b/synapse/storage/data_stores/state/store.py
@@ -0,0 +1,640 @@
+# -*- coding: utf-8 -*-
+# Copyright 2014-2016 OpenMarket Ltd
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import logging
+from collections import namedtuple
+
+from six import iteritems
+from six.moves import range
+
+from twisted.internet import defer
+
+from synapse.api.constants import EventTypes
+from synapse.storage._base import SQLBaseStore
+from synapse.storage.data_stores.state.bg_updates import StateBackgroundUpdateStore
+from synapse.storage.database import Database
+from synapse.storage.state import StateFilter
+from synapse.util.caches import get_cache_factor_for
+from synapse.util.caches.descriptors import cached
+from synapse.util.caches.dictionary_cache import DictionaryCache
+
+logger = logging.getLogger(__name__)
+
+
+MAX_STATE_DELTA_HOPS = 100
+
+
+class _GetStateGroupDelta(
+    namedtuple("_GetStateGroupDelta", ("prev_group", "delta_ids"))
+):
+    """Return type of get_state_group_delta that implements __len__, which lets
+    us use the itrable flag when caching
+    """
+
+    __slots__ = []
+
+    def __len__(self):
+        return len(self.delta_ids) if self.delta_ids else 0
+
+
+class StateGroupDataStore(StateBackgroundUpdateStore, SQLBaseStore):
+    """A data store for fetching/storing state groups.
+    """
+
+    def __init__(self, database: Database, db_conn, hs):
+        super(StateGroupDataStore, self).__init__(database, db_conn, hs)
+
+        # Originally the state store used a single DictionaryCache to cache the
+        # event IDs for the state types in a given state group to avoid hammering
+        # on the state_group* tables.
+        #
+        # The point of using a DictionaryCache is that it can cache a subset
+        # of the state events for a given state group (i.e. a subset of the keys for a
+        # given dict which is an entry in the cache for a given state group ID).
+        #
+        # However, this poses problems when performing complicated queries
+        # on the store - for instance: "give me all the state for this group, but
+        # limit members to this subset of users", as DictionaryCache's API isn't
+        # rich enough to say "please cache any of these fields, apart from this subset".
+        # This is problematic when lazy loading members, which requires this behaviour,
+        # as without it the cache has no choice but to speculatively load all
+        # state events for the group, which negates the efficiency being sought.
+        #
+        # Rather than overcomplicating DictionaryCache's API, we instead split the
+        # state_group_cache into two halves - one for tracking non-member events,
+        # and the other for tracking member_events.  This means that lazy loading
+        # queries can be made in a cache-friendly manner by querying both caches
+        # separately and then merging the result.  So for the example above, you
+        # would query the members cache for a specific subset of state keys
+        # (which DictionaryCache will handle efficiently and fine) and the non-members
+        # cache for all state (which DictionaryCache will similarly handle fine)
+        # and then just merge the results together.
+        #
+        # We size the non-members cache to be smaller than the members cache as the
+        # vast majority of state in Matrix (today) is member events.
+
+        self._state_group_cache = DictionaryCache(
+            "*stateGroupCache*",
+            # TODO: this hasn't been tuned yet
+            50000 * get_cache_factor_for("stateGroupCache"),
+        )
+        self._state_group_members_cache = DictionaryCache(
+            "*stateGroupMembersCache*",
+            500000 * get_cache_factor_for("stateGroupMembersCache"),
+        )
+
+    @cached(max_entries=10000, iterable=True)
+    def get_state_group_delta(self, state_group):
+        """Given a state group try to return a previous group and a delta between
+        the old and the new.
+
+        Returns:
+            (prev_group, delta_ids), where both may be None.
+        """
+
+        def _get_state_group_delta_txn(txn):
+            prev_group = self.db.simple_select_one_onecol_txn(
+                txn,
+                table="state_group_edges",
+                keyvalues={"state_group": state_group},
+                retcol="prev_state_group",
+                allow_none=True,
+            )
+
+            if not prev_group:
+                return _GetStateGroupDelta(None, None)
+
+            delta_ids = self.db.simple_select_list_txn(
+                txn,
+                table="state_groups_state",
+                keyvalues={"state_group": state_group},
+                retcols=("type", "state_key", "event_id"),
+            )
+
+            return _GetStateGroupDelta(
+                prev_group,
+                {(row["type"], row["state_key"]): row["event_id"] for row in delta_ids},
+            )
+
+        return self.db.runInteraction(
+            "get_state_group_delta", _get_state_group_delta_txn
+        )
+
+    @defer.inlineCallbacks
+    def _get_state_groups_from_groups(self, groups, state_filter):
+        """Returns the state groups for a given set of groups, filtering on
+        types of state events.
+
+        Args:
+            groups(list[int]): list of state group IDs to query
+            state_filter (StateFilter): The state filter used to fetch state
+                from the database.
+        Returns:
+            Deferred[dict[int, dict[tuple[str, str], str]]]:
+                dict of state_group_id -> (dict of (type, state_key) -> event id)
+        """
+        results = {}
+
+        chunks = [groups[i : i + 100] for i in range(0, len(groups), 100)]
+        for chunk in chunks:
+            res = yield self.db.runInteraction(
+                "_get_state_groups_from_groups",
+                self._get_state_groups_from_groups_txn,
+                chunk,
+                state_filter,
+            )
+            results.update(res)
+
+        return results
+
+    def _get_state_for_group_using_cache(self, cache, group, state_filter):
+        """Checks if group is in cache. See `_get_state_for_groups`
+
+        Args:
+            cache(DictionaryCache): the state group cache to use
+            group(int): The state group to lookup
+            state_filter (StateFilter): The state filter used to fetch state
+                from the database.
+
+        Returns 2-tuple (`state_dict`, `got_all`).
+        `got_all` is a bool indicating if we successfully retrieved all
+        requests state from the cache, if False we need to query the DB for the
+        missing state.
+        """
+        is_all, known_absent, state_dict_ids = cache.get(group)
+
+        if is_all or state_filter.is_full():
+            # Either we have everything or want everything, either way
+            # `is_all` tells us whether we've gotten everything.
+            return state_filter.filter_state(state_dict_ids), is_all
+
+        # tracks whether any of our requested types are missing from the cache
+        missing_types = False
+
+        if state_filter.has_wildcards():
+            # We don't know if we fetched all the state keys for the types in
+            # the filter that are wildcards, so we have to assume that we may
+            # have missed some.
+            missing_types = True
+        else:
+            # There aren't any wild cards, so `concrete_types()` returns the
+            # complete list of event types we're wanting.
+            for key in state_filter.concrete_types():
+                if key not in state_dict_ids and key not in known_absent:
+                    missing_types = True
+                    break
+
+        return state_filter.filter_state(state_dict_ids), not missing_types
+
+    @defer.inlineCallbacks
+    def _get_state_for_groups(self, groups, state_filter=StateFilter.all()):
+        """Gets the state at each of a list of state groups, optionally
+        filtering by type/state_key
+
+        Args:
+            groups (iterable[int]): list of state groups for which we want
+                to get the state.
+            state_filter (StateFilter): The state filter used to fetch state
+                from the database.
+        Returns:
+            Deferred[dict[int, dict[tuple[str, str], str]]]:
+                dict of state_group_id -> (dict of (type, state_key) -> event id)
+        """
+
+        member_filter, non_member_filter = state_filter.get_member_split()
+
+        # Now we look them up in the member and non-member caches
+        (
+            non_member_state,
+            incomplete_groups_nm,
+        ) = yield self._get_state_for_groups_using_cache(
+            groups, self._state_group_cache, state_filter=non_member_filter
+        )
+
+        (
+            member_state,
+            incomplete_groups_m,
+        ) = yield self._get_state_for_groups_using_cache(
+            groups, self._state_group_members_cache, state_filter=member_filter
+        )
+
+        state = dict(non_member_state)
+        for group in groups:
+            state[group].update(member_state[group])
+
+        # Now fetch any missing groups from the database
+
+        incomplete_groups = incomplete_groups_m | incomplete_groups_nm
+
+        if not incomplete_groups:
+            return state
+
+        cache_sequence_nm = self._state_group_cache.sequence
+        cache_sequence_m = self._state_group_members_cache.sequence
+
+        # Help the cache hit ratio by expanding the filter a bit
+        db_state_filter = state_filter.return_expanded()
+
+        group_to_state_dict = yield self._get_state_groups_from_groups(
+            list(incomplete_groups), state_filter=db_state_filter
+        )
+
+        # Now lets update the caches
+        self._insert_into_cache(
+            group_to_state_dict,
+            db_state_filter,
+            cache_seq_num_members=cache_sequence_m,
+            cache_seq_num_non_members=cache_sequence_nm,
+        )
+
+        # And finally update the result dict, by filtering out any extra
+        # stuff we pulled out of the database.
+        for group, group_state_dict in iteritems(group_to_state_dict):
+            # We just replace any existing entries, as we will have loaded
+            # everything we need from the database anyway.
+            state[group] = state_filter.filter_state(group_state_dict)
+
+        return state
+
+    def _get_state_for_groups_using_cache(self, groups, cache, state_filter):
+        """Gets the state at each of a list of state groups, optionally
+        filtering by type/state_key, querying from a specific cache.
+
+        Args:
+            groups (iterable[int]): list of state groups for which we want
+                to get the state.
+            cache (DictionaryCache): the cache of group ids to state dicts which
+                we will pass through - either the normal state cache or the specific
+                members state cache.
+            state_filter (StateFilter): The state filter used to fetch state
+                from the database.
+
+        Returns:
+            tuple[dict[int, dict[tuple[str, str], str]], set[int]]: Tuple of
+            dict of state_group_id -> (dict of (type, state_key) -> event id)
+            of entries in the cache, and the state group ids either missing
+            from the cache or incomplete.
+        """
+        results = {}
+        incomplete_groups = set()
+        for group in set(groups):
+            state_dict_ids, got_all = self._get_state_for_group_using_cache(
+                cache, group, state_filter
+            )
+            results[group] = state_dict_ids
+
+            if not got_all:
+                incomplete_groups.add(group)
+
+        return results, incomplete_groups
+
+    def _insert_into_cache(
+        self,
+        group_to_state_dict,
+        state_filter,
+        cache_seq_num_members,
+        cache_seq_num_non_members,
+    ):
+        """Inserts results from querying the database into the relevant cache.
+
+        Args:
+            group_to_state_dict (dict): The new entries pulled from database.
+                Map from state group to state dict
+            state_filter (StateFilter): The state filter used to fetch state
+                from the database.
+            cache_seq_num_members (int): Sequence number of member cache since
+                last lookup in cache
+            cache_seq_num_non_members (int): Sequence number of member cache since
+                last lookup in cache
+        """
+
+        # We need to work out which types we've fetched from the DB for the
+        # member vs non-member caches. This should be as accurate as possible,
+        # but can be an underestimate (e.g. when we have wild cards)
+
+        member_filter, non_member_filter = state_filter.get_member_split()
+        if member_filter.is_full():
+            # We fetched all member events
+            member_types = None
+        else:
+            # `concrete_types()` will only return a subset when there are wild
+            # cards in the filter, but that's fine.
+            member_types = member_filter.concrete_types()
+
+        if non_member_filter.is_full():
+            # We fetched all non member events
+            non_member_types = None
+        else:
+            non_member_types = non_member_filter.concrete_types()
+
+        for group, group_state_dict in iteritems(group_to_state_dict):
+            state_dict_members = {}
+            state_dict_non_members = {}
+
+            for k, v in iteritems(group_state_dict):
+                if k[0] == EventTypes.Member:
+                    state_dict_members[k] = v
+                else:
+                    state_dict_non_members[k] = v
+
+            self._state_group_members_cache.update(
+                cache_seq_num_members,
+                key=group,
+                value=state_dict_members,
+                fetched_keys=member_types,
+            )
+
+            self._state_group_cache.update(
+                cache_seq_num_non_members,
+                key=group,
+                value=state_dict_non_members,
+                fetched_keys=non_member_types,
+            )
+
+    def store_state_group(
+        self, event_id, room_id, prev_group, delta_ids, current_state_ids
+    ):
+        """Store a new set of state, returning a newly assigned state group.
+
+        Args:
+            event_id (str): The event ID for which the state was calculated
+            room_id (str)
+            prev_group (int|None): A previous state group for the room, optional.
+            delta_ids (dict|None): The delta between state at `prev_group` and
+                `current_state_ids`, if `prev_group` was given. Same format as
+                `current_state_ids`.
+            current_state_ids (dict): The state to store. Map of (type, state_key)
+                to event_id.
+
+        Returns:
+            Deferred[int]: The state group ID
+        """
+
+        def _store_state_group_txn(txn):
+            if current_state_ids is None:
+                # AFAIK, this can never happen
+                raise Exception("current_state_ids cannot be None")
+
+            state_group = self.database_engine.get_next_state_group_id(txn)
+
+            self.db.simple_insert_txn(
+                txn,
+                table="state_groups",
+                values={"id": state_group, "room_id": room_id, "event_id": event_id},
+            )
+
+            # We persist as a delta if we can, while also ensuring the chain
+            # of deltas isn't tooo long, as otherwise read performance degrades.
+            if prev_group:
+                is_in_db = self.db.simple_select_one_onecol_txn(
+                    txn,
+                    table="state_groups",
+                    keyvalues={"id": prev_group},
+                    retcol="id",
+                    allow_none=True,
+                )
+                if not is_in_db:
+                    raise Exception(
+                        "Trying to persist state with unpersisted prev_group: %r"
+                        % (prev_group,)
+                    )
+
+                potential_hops = self._count_state_group_hops_txn(txn, prev_group)
+            if prev_group and potential_hops < MAX_STATE_DELTA_HOPS:
+                self.db.simple_insert_txn(
+                    txn,
+                    table="state_group_edges",
+                    values={"state_group": state_group, "prev_state_group": prev_group},
+                )
+
+                self.db.simple_insert_many_txn(
+                    txn,
+                    table="state_groups_state",
+                    values=[
+                        {
+                            "state_group": state_group,
+                            "room_id": room_id,
+                            "type": key[0],
+                            "state_key": key[1],
+                            "event_id": state_id,
+                        }
+                        for key, state_id in iteritems(delta_ids)
+                    ],
+                )
+            else:
+                self.db.simple_insert_many_txn(
+                    txn,
+                    table="state_groups_state",
+                    values=[
+                        {
+                            "state_group": state_group,
+                            "room_id": room_id,
+                            "type": key[0],
+                            "state_key": key[1],
+                            "event_id": state_id,
+                        }
+                        for key, state_id in iteritems(current_state_ids)
+                    ],
+                )
+
+            # Prefill the state group caches with this group.
+            # It's fine to use the sequence like this as the state group map
+            # is immutable. (If the map wasn't immutable then this prefill could
+            # race with another update)
+
+            current_member_state_ids = {
+                s: ev
+                for (s, ev) in iteritems(current_state_ids)
+                if s[0] == EventTypes.Member
+            }
+            txn.call_after(
+                self._state_group_members_cache.update,
+                self._state_group_members_cache.sequence,
+                key=state_group,
+                value=dict(current_member_state_ids),
+            )
+
+            current_non_member_state_ids = {
+                s: ev
+                for (s, ev) in iteritems(current_state_ids)
+                if s[0] != EventTypes.Member
+            }
+            txn.call_after(
+                self._state_group_cache.update,
+                self._state_group_cache.sequence,
+                key=state_group,
+                value=dict(current_non_member_state_ids),
+            )
+
+            return state_group
+
+        return self.db.runInteraction("store_state_group", _store_state_group_txn)
+
+    def purge_unreferenced_state_groups(
+        self, room_id: str, state_groups_to_delete
+    ) -> defer.Deferred:
+        """Deletes no longer referenced state groups and de-deltas any state
+        groups that reference them.
+
+        Args:
+            room_id: The room the state groups belong to (must all be in the
+                same room).
+            state_groups_to_delete (Collection[int]): Set of all state groups
+                to delete.
+        """
+
+        return self.db.runInteraction(
+            "purge_unreferenced_state_groups",
+            self._purge_unreferenced_state_groups,
+            room_id,
+            state_groups_to_delete,
+        )
+
+    def _purge_unreferenced_state_groups(self, txn, room_id, state_groups_to_delete):
+        logger.info(
+            "[purge] found %i state groups to delete", len(state_groups_to_delete)
+        )
+
+        rows = self.db.simple_select_many_txn(
+            txn,
+            table="state_group_edges",
+            column="prev_state_group",
+            iterable=state_groups_to_delete,
+            keyvalues={},
+            retcols=("state_group",),
+        )
+
+        remaining_state_groups = set(
+            row["state_group"]
+            for row in rows
+            if row["state_group"] not in state_groups_to_delete
+        )
+
+        logger.info(
+            "[purge] de-delta-ing %i remaining state groups",
+            len(remaining_state_groups),
+        )
+
+        # Now we turn the state groups that reference to-be-deleted state
+        # groups to non delta versions.
+        for sg in remaining_state_groups:
+            logger.info("[purge] de-delta-ing remaining state group %s", sg)
+            curr_state = self._get_state_groups_from_groups_txn(txn, [sg])
+            curr_state = curr_state[sg]
+
+            self.db.simple_delete_txn(
+                txn, table="state_groups_state", keyvalues={"state_group": sg}
+            )
+
+            self.db.simple_delete_txn(
+                txn, table="state_group_edges", keyvalues={"state_group": sg}
+            )
+
+            self.db.simple_insert_many_txn(
+                txn,
+                table="state_groups_state",
+                values=[
+                    {
+                        "state_group": sg,
+                        "room_id": room_id,
+                        "type": key[0],
+                        "state_key": key[1],
+                        "event_id": state_id,
+                    }
+                    for key, state_id in iteritems(curr_state)
+                ],
+            )
+
+        logger.info("[purge] removing redundant state groups")
+        txn.executemany(
+            "DELETE FROM state_groups_state WHERE state_group = ?",
+            ((sg,) for sg in state_groups_to_delete),
+        )
+        txn.executemany(
+            "DELETE FROM state_groups WHERE id = ?",
+            ((sg,) for sg in state_groups_to_delete),
+        )
+
+    @defer.inlineCallbacks
+    def get_previous_state_groups(self, state_groups):
+        """Fetch the previous groups of the given state groups.
+
+        Args:
+            state_groups (Iterable[int])
+
+        Returns:
+            Deferred[dict[int, int]]: mapping from state group to previous
+            state group.
+        """
+
+        rows = yield self.db.simple_select_many_batch(
+            table="state_group_edges",
+            column="prev_state_group",
+            iterable=state_groups,
+            keyvalues={},
+            retcols=("prev_state_group", "state_group"),
+            desc="get_previous_state_groups",
+        )
+
+        return {row["state_group"]: row["prev_state_group"] for row in rows}
+
+    def purge_room_state(self, room_id, state_groups_to_delete):
+        """Deletes all record of a room from state tables
+
+        Args:
+            room_id (str):
+            state_groups_to_delete (list[int]): State groups to delete
+        """
+
+        return self.db.runInteraction(
+            "purge_room_state",
+            self._purge_room_state_txn,
+            room_id,
+            state_groups_to_delete,
+        )
+
+    def _purge_room_state_txn(self, txn, room_id, state_groups_to_delete):
+        # first we have to delete the state groups states
+        logger.info("[purge] removing %s from state_groups_state", room_id)
+
+        self.db.simple_delete_many_txn(
+            txn,
+            table="state_groups_state",
+            column="state_group",
+            iterable=state_groups_to_delete,
+            keyvalues={},
+        )
+
+        # ... and the state group edges
+        logger.info("[purge] removing %s from state_group_edges", room_id)
+
+        self.db.simple_delete_many_txn(
+            txn,
+            table="state_group_edges",
+            column="state_group",
+            iterable=state_groups_to_delete,
+            keyvalues={},
+        )
+
+        # ... and the state groups
+        logger.info("[purge] removing %s from state_groups", room_id)
+
+        self.db.simple_delete_many_txn(
+            txn,
+            table="state_groups",
+            column="id",
+            iterable=state_groups_to_delete,
+            keyvalues={},
+        )
diff --git a/synapse/storage/database.py b/synapse/storage/database.py
index ec19ae1d9d..1003dd84a5 100644
--- a/synapse/storage/database.py
+++ b/synapse/storage/database.py
@@ -24,9 +24,11 @@ from six.moves import intern, range
 
 from prometheus_client import Histogram
 
+from twisted.enterprise import adbapi
 from twisted.internet import defer
 
 from synapse.api.errors import StoreError
+from synapse.config.database import DatabaseConnectionConfig
 from synapse.logging.context import LoggingContext, make_deferred_yieldable
 from synapse.metrics.background_process_metrics import run_as_background_process
 from synapse.storage.background_updates import BackgroundUpdater
@@ -74,6 +76,37 @@ UNIQUE_INDEX_BACKGROUND_UPDATES = {
 }
 
 
+def make_pool(
+    reactor, db_config: DatabaseConnectionConfig, engine
+) -> adbapi.ConnectionPool:
+    """Get the connection pool for the database.
+    """
+
+    return adbapi.ConnectionPool(
+        db_config.config["name"],
+        cp_reactor=reactor,
+        cp_openfun=engine.on_new_connection,
+        **db_config.config.get("args", {})
+    )
+
+
+def make_conn(db_config: DatabaseConnectionConfig, engine):
+    """Make a new connection to the database and return it.
+
+    Returns:
+        Connection
+    """
+
+    db_params = {
+        k: v
+        for k, v in db_config.config.get("args", {}).items()
+        if not k.startswith("cp_")
+    }
+    db_conn = engine.module.connect(**db_params)
+    engine.on_new_connection(db_conn)
+    return db_conn
+
+
 class LoggingTransaction(object):
     """An object that almost-transparently proxies for the 'txn' object
     passed to the constructor. Adds logging and metrics to the .execute()
@@ -218,10 +251,11 @@ class Database(object):
 
     _TXN_ID = 0
 
-    def __init__(self, hs):
+    def __init__(self, hs, database_config: DatabaseConnectionConfig, engine):
         self.hs = hs
         self._clock = hs.get_clock()
-        self._db_pool = hs.get_db_pool()
+        self._database_config = database_config
+        self._db_pool = make_pool(hs.get_reactor(), database_config, engine)
 
         self.updates = BackgroundUpdater(hs, self)
 
@@ -234,7 +268,7 @@ class Database(object):
         #   to watch it
         self._txn_perf_counters = PerformanceCounters()
 
-        self.engine = hs.database_engine
+        self.engine = engine
 
         # A set of tables that are not safe to use native upserts in.
         self._unsafe_to_upsert_tables = set(UNIQUE_INDEX_BACKGROUND_UPDATES.keys())
@@ -255,6 +289,11 @@ class Database(object):
                 self._check_safe_to_upsert,
             )
 
+    def is_running(self):
+        """Is the database pool currently running
+        """
+        return self._db_pool.running
+
     @defer.inlineCallbacks
     def _check_safe_to_upsert(self):
         """
diff --git a/synapse/storage/engines/sqlite.py b/synapse/storage/engines/sqlite.py
index cbc74cd302..df039a072d 100644
--- a/synapse/storage/engines/sqlite.py
+++ b/synapse/storage/engines/sqlite.py
@@ -16,8 +16,6 @@
 import struct
 import threading
 
-from synapse.storage.prepare_database import prepare_database
-
 
 class Sqlite3Engine(object):
     single_threaded = True
@@ -62,6 +60,10 @@ class Sqlite3Engine(object):
         return sql
 
     def on_new_connection(self, db_conn):
+
+        # We need to import here to avoid an import loop.
+        from synapse.storage.prepare_database import prepare_database
+
         if self._is_in_memory:
             # In memory databases need to be rebuilt each time. Ideally we'd
             # reuse the same connection as we do when starting up, but that
diff --git a/synapse/storage/persist_events.py b/synapse/storage/persist_events.py
index fa03ca9ff7..1ed44925fc 100644
--- a/synapse/storage/persist_events.py
+++ b/synapse/storage/persist_events.py
@@ -183,7 +183,7 @@ class EventsPersistenceStorage(object):
         # so we use separate variables here even though they point to the same
         # store for now.
         self.main_store = stores.main
-        self.state_store = stores.main
+        self.state_store = stores.state
 
         self._clock = hs.get_clock()
         self.is_mine_id = hs.is_mine_id
diff --git a/synapse/storage/prepare_database.py b/synapse/storage/prepare_database.py
index 731e1c9d9c..e70026b80a 100644
--- a/synapse/storage/prepare_database.py
+++ b/synapse/storage/prepare_database.py
@@ -18,6 +18,7 @@ import imp
 import logging
 import os
 import re
+from collections import Counter
 
 import attr
 
@@ -41,7 +42,7 @@ class UpgradeDatabaseException(PrepareDatabaseException):
     pass
 
 
-def prepare_database(db_conn, database_engine, config):
+def prepare_database(db_conn, database_engine, config, data_stores=["main", "state"]):
     """Prepares a database for usage. Will either create all necessary tables
     or upgrade from an older schema version.
 
@@ -54,11 +55,10 @@ def prepare_database(db_conn, database_engine, config):
         config (synapse.config.homeserver.HomeServerConfig|None):
             application config, or None if we are connecting to an existing
             database which we expect to be configured already
+        data_stores (list[str]): The name of the data stores that will be used
+            with this database. Defaults to all data stores.
     """
 
-    # For now we only have the one datastore.
-    data_stores = ["main"]
-
     try:
         cur = db_conn.cursor()
         version_info = _get_or_create_schema_state(cur, database_engine)
@@ -70,7 +70,10 @@ def prepare_database(db_conn, database_engine, config):
                 if user_version != SCHEMA_VERSION:
                     # If we don't pass in a config file then we are expecting to
                     # have already upgraded the DB.
-                    raise UpgradeDatabaseException("Database needs to be upgraded")
+                    raise UpgradeDatabaseException(
+                        "Expected database schema version %i but got %i"
+                        % (SCHEMA_VERSION, user_version)
+                    )
             else:
                 _upgrade_existing_database(
                     cur,
@@ -313,6 +316,9 @@ def _upgrade_existing_database(
                 )
             )
 
+        # Used to check if we have any duplicate file names
+        file_name_counter = Counter()
+
         # Now find which directories have anything of interest.
         directory_entries = []
         for directory in directories:
@@ -323,6 +329,9 @@ def _upgrade_existing_database(
                     _DirectoryListing(file_name, os.path.join(directory, file_name))
                     for file_name in file_names
                 )
+
+                for file_name in file_names:
+                    file_name_counter[file_name] += 1
             except FileNotFoundError:
                 # Data stores can have empty entries for a given version delta.
                 pass
@@ -331,6 +340,17 @@ def _upgrade_existing_database(
                     "Could not open delta dir for version %d: %s" % (v, directory)
                 )
 
+        duplicates = set(
+            file_name for file_name, count in file_name_counter.items() if count > 1
+        )
+        if duplicates:
+            # We don't support using the same file name in the same delta version.
+            raise PrepareDatabaseException(
+                "Found multiple delta files with the same name in v%d: %s",
+                v,
+                duplicates,
+            )
+
         # We sort to ensure that we apply the delta files in a consistent
         # order (to avoid bugs caused by inconsistent directory listing order)
         directory_entries.sort()
diff --git a/synapse/storage/purge_events.py b/synapse/storage/purge_events.py
index a368182034..d6a7bd7834 100644
--- a/synapse/storage/purge_events.py
+++ b/synapse/storage/purge_events.py
@@ -58,7 +58,7 @@ class PurgeEventsStorage(object):
 
         sg_to_delete = yield self._find_unreferenced_groups(state_groups)
 
-        yield self.stores.main.purge_unreferenced_state_groups(room_id, sg_to_delete)
+        yield self.stores.state.purge_unreferenced_state_groups(room_id, sg_to_delete)
 
     @defer.inlineCallbacks
     def _find_unreferenced_groups(self, state_groups):
@@ -102,7 +102,7 @@ class PurgeEventsStorage(object):
             # groups that are referenced.
             current_search -= referenced
 
-            edges = yield self.stores.main.get_previous_state_groups(current_search)
+            edges = yield self.stores.state.get_previous_state_groups(current_search)
 
             prevs = set(edges.values())
             # We don't bother re-handling groups we've already seen
diff --git a/synapse/storage/state.py b/synapse/storage/state.py
index 3735846899..cbeb586014 100644
--- a/synapse/storage/state.py
+++ b/synapse/storage/state.py
@@ -342,7 +342,7 @@ class StateGroupStorage(object):
                 (prev_group, delta_ids)
         """
 
-        return self.stores.main.get_state_group_delta(state_group)
+        return self.stores.state.get_state_group_delta(state_group)
 
     @defer.inlineCallbacks
     def get_state_groups_ids(self, _room_id, event_ids):
@@ -362,7 +362,7 @@ class StateGroupStorage(object):
         event_to_groups = yield self.stores.main._get_state_group_for_events(event_ids)
 
         groups = set(itervalues(event_to_groups))
-        group_to_state = yield self.stores.main._get_state_for_groups(groups)
+        group_to_state = yield self.stores.state._get_state_for_groups(groups)
 
         return group_to_state
 
@@ -423,7 +423,7 @@ class StateGroupStorage(object):
                 dict of state_group_id -> (dict of (type, state_key) -> event id)
         """
 
-        return self.stores.main._get_state_groups_from_groups(groups, state_filter)
+        return self.stores.state._get_state_groups_from_groups(groups, state_filter)
 
     @defer.inlineCallbacks
     def get_state_for_events(self, event_ids, state_filter=StateFilter.all()):
@@ -439,7 +439,7 @@ class StateGroupStorage(object):
         event_to_groups = yield self.stores.main._get_state_group_for_events(event_ids)
 
         groups = set(itervalues(event_to_groups))
-        group_to_state = yield self.stores.main._get_state_for_groups(
+        group_to_state = yield self.stores.state._get_state_for_groups(
             groups, state_filter
         )
 
@@ -476,7 +476,7 @@ class StateGroupStorage(object):
         event_to_groups = yield self.stores.main._get_state_group_for_events(event_ids)
 
         groups = set(itervalues(event_to_groups))
-        group_to_state = yield self.stores.main._get_state_for_groups(
+        group_to_state = yield self.stores.state._get_state_for_groups(
             groups, state_filter
         )
 
@@ -532,7 +532,7 @@ class StateGroupStorage(object):
             Deferred[dict[int, dict[tuple[str, str], str]]]:
                 dict of state_group_id -> (dict of (type, state_key) -> event id)
         """
-        return self.stores.main._get_state_for_groups(groups, state_filter)
+        return self.stores.state._get_state_for_groups(groups, state_filter)
 
     def store_state_group(
         self, event_id, room_id, prev_group, delta_ids, current_state_ids
@@ -552,6 +552,6 @@ class StateGroupStorage(object):
         Returns:
             Deferred[int]: The state group ID
         """
-        return self.stores.main.store_state_group(
+        return self.stores.state.store_state_group(
             event_id, room_id, prev_group, delta_ids, current_state_ids
         )
diff --git a/synapse/streams/events.py b/synapse/streams/events.py
index b91fb2db7b..fcd2aaa9c9 100644
--- a/synapse/streams/events.py
+++ b/synapse/streams/events.py
@@ -13,6 +13,8 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
+from typing import Any, Dict
+
 from twisted.internet import defer
 
 from synapse.handlers.account_data import AccountDataEventSource
@@ -35,7 +37,7 @@ class EventSources(object):
     def __init__(self, hs):
         self.sources = {
             name: cls(hs) for name, cls in EventSources.SOURCE_TYPES.items()
-        }
+        }  # type: Dict[str, Any]
         self.store = hs.get_datastore()
 
     @defer.inlineCallbacks
diff --git a/synapse/util/caches/descriptors.py b/synapse/util/caches/descriptors.py
index 84f5ae22c3..2e8f6543e5 100644
--- a/synapse/util/caches/descriptors.py
+++ b/synapse/util/caches/descriptors.py
@@ -271,7 +271,7 @@ class _CacheDescriptorBase(object):
         else:
             self.function_to_call = orig
 
-        arg_spec = inspect.getargspec(orig)
+        arg_spec = inspect.getfullargspec(orig)
         all_args = arg_spec.args
 
         if "cache_context" in all_args:
diff --git a/synapse/util/caches/snapshot_cache.py b/synapse/util/caches/snapshot_cache.py
deleted file mode 100644
index 8318db8d2c..0000000000
--- a/synapse/util/caches/snapshot_cache.py
+++ /dev/null
@@ -1,94 +0,0 @@
-# -*- coding: utf-8 -*-
-# Copyright 2015, 2016 OpenMarket Ltd
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-#     http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-
-from synapse.util.async_helpers import ObservableDeferred
-
-
-class SnapshotCache(object):
-    """Cache for snapshots like the response of /initialSync.
-    The response of initialSync only has to be a recent snapshot of the
-    server state. It shouldn't matter to clients if it is a few minutes out
-    of date.
-
-    This caches a deferred response. Until the deferred completes it will be
-    returned from the cache. This means that if the client retries the request
-    while the response is still being computed, that original response will be
-    used rather than trying to compute a new response.
-
-    Once the deferred completes it will removed from the cache after 5 minutes.
-    We delay removing it from the cache because a client retrying its request
-    could race with us finishing computing the response.
-
-    Rather than tracking precisely how long something has been in the cache we
-    keep two generations of completed responses. Every 5 minutes discard the
-    old generation, move the new generation to the old generation, and set the
-    new generation to be empty. This means that a result will be in the cache
-    somewhere between 5 and 10 minutes.
-    """
-
-    DURATION_MS = 5 * 60 * 1000  # Cache results for 5 minutes.
-
-    def __init__(self):
-        self.pending_result_cache = {}  # Request that haven't finished yet.
-        self.prev_result_cache = {}  # The older requests that have finished.
-        self.next_result_cache = {}  # The newer requests that have finished.
-        self.time_last_rotated_ms = 0
-
-    def rotate(self, time_now_ms):
-        # Rotate once if the cache duration has passed since the last rotation.
-        if time_now_ms - self.time_last_rotated_ms >= self.DURATION_MS:
-            self.prev_result_cache = self.next_result_cache
-            self.next_result_cache = {}
-            self.time_last_rotated_ms += self.DURATION_MS
-
-        # Rotate again if the cache duration has passed twice since the last
-        # rotation.
-        if time_now_ms - self.time_last_rotated_ms >= self.DURATION_MS:
-            self.prev_result_cache = self.next_result_cache
-            self.next_result_cache = {}
-            self.time_last_rotated_ms = time_now_ms
-
-    def get(self, time_now_ms, key):
-        self.rotate(time_now_ms)
-        # This cache is intended to deduplicate requests, so we expect it to be
-        # missed most of the time. So we just lookup the key in all of the
-        # dictionaries rather than trying to short circuit the lookup if the
-        # key is found.
-        result = self.prev_result_cache.get(key)
-        result = self.next_result_cache.get(key, result)
-        result = self.pending_result_cache.get(key, result)
-        if result is not None:
-            return result.observe()
-        else:
-            return None
-
-    def set(self, time_now_ms, key, deferred):
-        self.rotate(time_now_ms)
-
-        result = ObservableDeferred(deferred)
-
-        self.pending_result_cache[key] = result
-
-        def shuffle_along(r):
-            # When the deferred completes we shuffle it along to the first
-            # generation of the result cache. So that it will eventually
-            # expire from the rotation of that cache.
-            self.next_result_cache[key] = result
-            self.pending_result_cache.pop(key, None)
-            return r
-
-        result.addBoth(shuffle_along)
-
-        return result.observe()