summary refs log tree commit diff
path: root/synapse
diff options
context:
space:
mode:
authorErik Johnston <erik@matrix.org>2019-12-11 17:01:41 +0000
committerErik Johnston <erik@matrix.org>2019-12-11 17:01:41 +0000
commit6828b47c459ea15d2815fe880616aa260a546838 (patch)
treebdcc359e06bd7f7f977c976fe1d770bd708b0f55 /synapse
parentNewsfile (diff)
parentAdd `include_event_in_state` to _get_state_for_room (#6521) (diff)
downloadsynapse-6828b47c459ea15d2815fe880616aa260a546838.tar.xz
Merge branch 'develop' of github.com:matrix-org/synapse into erikj/initial_sync_asnyc
Diffstat (limited to 'synapse')
-rw-r--r--synapse/__init__.py2
-rw-r--r--synapse/app/federation_sender.py5
-rw-r--r--synapse/app/homeserver.py41
-rw-r--r--synapse/app/pusher.py7
-rw-r--r--synapse/app/user_dir.py5
-rw-r--r--synapse/config/saml2_config.py186
-rw-r--r--synapse/event_auth.py8
-rw-r--r--synapse/federation/federation_client.py191
-rw-r--r--synapse/federation/federation_server.py15
-rw-r--r--synapse/federation/transport/client.py33
-rw-r--r--synapse/federation/transport/server.py32
-rw-r--r--synapse/handlers/account_data.py14
-rw-r--r--synapse/handlers/account_validity.py86
-rw-r--r--synapse/handlers/e2e_keys.py38
-rw-r--r--synapse/handlers/federation.py286
-rw-r--r--synapse/handlers/message.py5
-rw-r--r--synapse/handlers/pagination.py27
-rw-r--r--synapse/handlers/saml_handler.py198
-rw-r--r--synapse/handlers/search.py34
-rw-r--r--synapse/logging/context.py11
-rw-r--r--synapse/replication/slave/storage/_base.py5
-rw-r--r--synapse/replication/slave/storage/account_data.py5
-rw-r--r--synapse/replication/slave/storage/client_ips.py5
-rw-r--r--synapse/replication/slave/storage/deviceinbox.py5
-rw-r--r--synapse/replication/slave/storage/devices.py5
-rw-r--r--synapse/replication/slave/storage/events.py5
-rw-r--r--synapse/replication/slave/storage/filtering.py5
-rw-r--r--synapse/replication/slave/storage/groups.py5
-rw-r--r--synapse/replication/slave/storage/presence.py5
-rw-r--r--synapse/replication/slave/storage/push_rule.py5
-rw-r--r--synapse/replication/slave/storage/pushers.py5
-rw-r--r--synapse/replication/slave/storage/receipts.py5
-rw-r--r--synapse/replication/slave/storage/room.py5
-rw-r--r--synapse/rest/client/v1/profile.py15
-rw-r--r--synapse/server.py3
-rw-r--r--synapse/state/__init__.py3
-rw-r--r--synapse/storage/__init__.py12
-rw-r--r--synapse/storage/_base.py4
-rw-r--r--synapse/storage/background_updates.py2
-rw-r--r--synapse/storage/data_stores/__init__.py16
-rw-r--r--synapse/storage/data_stores/main/__init__.py29
-rw-r--r--synapse/storage/data_stores/main/account_data.py9
-rw-r--r--synapse/storage/data_stores/main/appservice.py5
-rw-r--r--synapse/storage/data_stores/main/client_ips.py17
-rw-r--r--synapse/storage/data_stores/main/deviceinbox.py9
-rw-r--r--synapse/storage/data_stores/main/devices.py9
-rw-r--r--synapse/storage/data_stores/main/event_federation.py5
-rw-r--r--synapse/storage/data_stores/main/event_push_actions.py9
-rw-r--r--synapse/storage/data_stores/main/events.py28
-rw-r--r--synapse/storage/data_stores/main/events_bg_updates.py13
-rw-r--r--synapse/storage/data_stores/main/events_worker.py102
-rw-r--r--synapse/storage/data_stores/main/media_repository.py11
-rw-r--r--synapse/storage/data_stores/main/monthly_active_users.py7
-rw-r--r--synapse/storage/data_stores/main/push_rule.py5
-rw-r--r--synapse/storage/data_stores/main/receipts.py9
-rw-r--r--synapse/storage/data_stores/main/registration.py13
-rw-r--r--synapse/storage/data_stores/main/room.py14
-rw-r--r--synapse/storage/data_stores/main/roommember.py13
-rw-r--r--synapse/storage/data_stores/main/schema/delta/56/redaction_censor.sql1
-rw-r--r--synapse/storage/data_stores/main/schema/delta/56/redaction_censor2.sql4
-rw-r--r--synapse/storage/data_stores/main/schema/delta/56/redaction_censor4.sql16
-rw-r--r--synapse/storage/data_stores/main/schema/full_schemas/README.md13
-rw-r--r--synapse/storage/data_stores/main/schema/full_schemas/README.txt19
-rw-r--r--synapse/storage/data_stores/main/search.py17
-rw-r--r--synapse/storage/data_stores/main/state.py31
-rw-r--r--synapse/storage/data_stores/main/stats.py5
-rw-r--r--synapse/storage/data_stores/main/stream.py5
-rw-r--r--synapse/storage/data_stores/main/transactions.py5
-rw-r--r--synapse/storage/data_stores/main/user_directory.py9
-rw-r--r--synapse/storage/database.py38
70 files changed, 1121 insertions, 693 deletions
diff --git a/synapse/__init__.py b/synapse/__init__.py
index f99de2f3f3..fc2a6e4ee6 100644
--- a/synapse/__init__.py
+++ b/synapse/__init__.py
@@ -36,7 +36,7 @@ try:
 except ImportError:
     pass
 
-__version__ = "1.6.1"
+__version__ = "1.7.0rc2"
 
 if bool(os.environ.get("SYNAPSE_TEST_PATCH_LOG_CONTEXTS", False)):
     # We import here so that we don't have to install a bunch of deps when
diff --git a/synapse/app/federation_sender.py b/synapse/app/federation_sender.py
index 448e45e00f..f24920a7d6 100644
--- a/synapse/app/federation_sender.py
+++ b/synapse/app/federation_sender.py
@@ -40,6 +40,7 @@ from synapse.replication.slave.storage.transactions import SlavedTransactionStor
 from synapse.replication.tcp.client import ReplicationClientHandler
 from synapse.replication.tcp.streams._base import ReceiptsStream
 from synapse.server import HomeServer
+from synapse.storage.database import Database
 from synapse.storage.engines import create_engine
 from synapse.types import ReadReceipt
 from synapse.util.async_helpers import Linearizer
@@ -59,8 +60,8 @@ class FederationSenderSlaveStore(
     SlavedDeviceStore,
     SlavedPresenceStore,
 ):
-    def __init__(self, db_conn, hs):
-        super(FederationSenderSlaveStore, self).__init__(db_conn, hs)
+    def __init__(self, database: Database, db_conn, hs):
+        super(FederationSenderSlaveStore, self).__init__(database, db_conn, hs)
 
         # We pull out the current federation stream position now so that we
         # always have a known value for the federation position in memory so
diff --git a/synapse/app/homeserver.py b/synapse/app/homeserver.py
index 9f81a857ab..032010600a 100644
--- a/synapse/app/homeserver.py
+++ b/synapse/app/homeserver.py
@@ -68,9 +68,9 @@ from synapse.rest.key.v2 import KeyApiV2Resource
 from synapse.rest.media.v0.content_repository import ContentRepoResource
 from synapse.rest.well_known import WellKnownResource
 from synapse.server import HomeServer
-from synapse.storage import DataStore, are_all_users_on_domain
+from synapse.storage import DataStore
 from synapse.storage.engines import IncorrectDatabaseSetup, create_engine
-from synapse.storage.prepare_database import UpgradeDatabaseException, prepare_database
+from synapse.storage.prepare_database import UpgradeDatabaseException
 from synapse.util.caches import CACHE_SIZE_FACTOR
 from synapse.util.httpresourcetree import create_resource_tree
 from synapse.util.manhole import manhole
@@ -294,22 +294,6 @@ class SynapseHomeServer(HomeServer):
             else:
                 logger.warning("Unrecognized listener type: %s", listener["type"])
 
-    def run_startup_checks(self, db_conn, database_engine):
-        all_users_native = are_all_users_on_domain(
-            db_conn.cursor(), database_engine, self.hostname
-        )
-        if not all_users_native:
-            quit_with_error(
-                "Found users in database not native to %s!\n"
-                "You cannot changed a synapse server_name after it's been configured"
-                % (self.hostname,)
-            )
-
-        try:
-            database_engine.check_database(db_conn.cursor())
-        except IncorrectDatabaseSetup as e:
-            quit_with_error(str(e))
-
 
 # Gauges to expose monthly active user control metrics
 current_mau_gauge = Gauge("synapse_admin_mau:current", "Current MAU")
@@ -357,16 +341,12 @@ def setup(config_options):
 
     synapse.config.logger.setup_logging(hs, config, use_worker_options=False)
 
-    logger.info("Preparing database: %s...", config.database_config["name"])
+    logger.info("Setting up server")
 
     try:
-        with hs.get_db_conn(run_new_connection=False) as db_conn:
-            prepare_database(db_conn, database_engine, config=config)
-            database_engine.on_new_connection(db_conn)
-
-            hs.run_startup_checks(db_conn, database_engine)
-
-            db_conn.commit()
+        hs.setup()
+    except IncorrectDatabaseSetup as e:
+        quit_with_error(str(e))
     except UpgradeDatabaseException:
         sys.stderr.write(
             "\nFailed to upgrade database.\n"
@@ -375,9 +355,6 @@ def setup(config_options):
         )
         sys.exit(1)
 
-    logger.info("Database prepared in %s.", config.database_config["name"])
-
-    hs.setup()
     hs.setup_master()
 
     @defer.inlineCallbacks
@@ -542,8 +519,10 @@ def phone_stats_home(hs, stats, stats_process=_stats_process):
     # Database version
     #
 
-    stats["database_engine"] = hs.database_engine.module.__name__
-    stats["database_server_version"] = hs.database_engine.server_version
+    # This only reports info about the *main* database.
+    stats["database_engine"] = hs.get_datastore().db.engine.module.__name__
+    stats["database_server_version"] = hs.get_datastore().db.engine.server_version
+
     logger.info("Reporting stats to %s: %s" % (hs.config.report_stats_endpoint, stats))
     try:
         yield hs.get_proxied_http_client().put_json(
diff --git a/synapse/app/pusher.py b/synapse/app/pusher.py
index 01a5ffc363..dd52a9fc2d 100644
--- a/synapse/app/pusher.py
+++ b/synapse/app/pusher.py
@@ -33,6 +33,7 @@ from synapse.replication.slave.storage.account_data import SlavedAccountDataStor
 from synapse.replication.slave.storage.events import SlavedEventStore
 from synapse.replication.slave.storage.pushers import SlavedPusherStore
 from synapse.replication.slave.storage.receipts import SlavedReceiptsStore
+from synapse.replication.slave.storage.room import RoomStore
 from synapse.replication.tcp.client import ReplicationClientHandler
 from synapse.server import HomeServer
 from synapse.storage import DataStore
@@ -45,7 +46,11 @@ logger = logging.getLogger("synapse.app.pusher")
 
 
 class PusherSlaveStore(
-    SlavedEventStore, SlavedPusherStore, SlavedReceiptsStore, SlavedAccountDataStore
+    SlavedEventStore,
+    SlavedPusherStore,
+    SlavedReceiptsStore,
+    SlavedAccountDataStore,
+    RoomStore,
 ):
     update_pusher_last_stream_ordering_and_success = __func__(
         DataStore.update_pusher_last_stream_ordering_and_success
diff --git a/synapse/app/user_dir.py b/synapse/app/user_dir.py
index b6d4481725..c01fb34a9b 100644
--- a/synapse/app/user_dir.py
+++ b/synapse/app/user_dir.py
@@ -43,6 +43,7 @@ from synapse.replication.tcp.streams.events import (
 from synapse.rest.client.v2_alpha import user_directory
 from synapse.server import HomeServer
 from synapse.storage.data_stores.main.user_directory import UserDirectoryStore
+from synapse.storage.database import Database
 from synapse.storage.engines import create_engine
 from synapse.util.caches.stream_change_cache import StreamChangeCache
 from synapse.util.httpresourcetree import create_resource_tree
@@ -60,8 +61,8 @@ class UserDirectorySlaveStore(
     UserDirectoryStore,
     BaseSlavedStore,
 ):
-    def __init__(self, db_conn, hs):
-        super(UserDirectorySlaveStore, self).__init__(db_conn, hs)
+    def __init__(self, database: Database, db_conn, hs):
+        super(UserDirectorySlaveStore, self).__init__(database, db_conn, hs)
 
         events_max = self._stream_id_gen.get_current_token()
         curr_state_delta_prefill, min_curr_state_delta_id = self.db.get_cache_dict(
diff --git a/synapse/config/saml2_config.py b/synapse/config/saml2_config.py
index c5ea2d43a1..b91414aa35 100644
--- a/synapse/config/saml2_config.py
+++ b/synapse/config/saml2_config.py
@@ -14,17 +14,19 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
-import re
+import logging
 
 from synapse.python_dependencies import DependencyException, check_requirements
-from synapse.types import (
-    map_username_to_mxid_localpart,
-    mxid_localpart_allowed_characters,
-)
-from synapse.util.module_loader import load_python_module
+from synapse.util.module_loader import load_module, load_python_module
 
 from ._base import Config, ConfigError
 
+logger = logging.getLogger(__name__)
+
+DEFAULT_USER_MAPPING_PROVIDER = (
+    "synapse.handlers.saml_handler.DefaultSamlMappingProvider"
+)
+
 
 def _dict_merge(merge_dict, into_dict):
     """Do a deep merge of two dicts
@@ -75,15 +77,69 @@ class SAML2Config(Config):
 
         self.saml2_enabled = True
 
-        self.saml2_mxid_source_attribute = saml2_config.get(
-            "mxid_source_attribute", "uid"
-        )
-
         self.saml2_grandfathered_mxid_source_attribute = saml2_config.get(
             "grandfathered_mxid_source_attribute", "uid"
         )
 
-        saml2_config_dict = self._default_saml_config_dict()
+        # user_mapping_provider may be None if the key is present but has no value
+        ump_dict = saml2_config.get("user_mapping_provider") or {}
+
+        # Use the default user mapping provider if not set
+        ump_dict.setdefault("module", DEFAULT_USER_MAPPING_PROVIDER)
+
+        # Ensure a config is present
+        ump_dict["config"] = ump_dict.get("config") or {}
+
+        if ump_dict["module"] == DEFAULT_USER_MAPPING_PROVIDER:
+            # Load deprecated options for use by the default module
+            old_mxid_source_attribute = saml2_config.get("mxid_source_attribute")
+            if old_mxid_source_attribute:
+                logger.warning(
+                    "The config option saml2_config.mxid_source_attribute is deprecated. "
+                    "Please use saml2_config.user_mapping_provider.config"
+                    ".mxid_source_attribute instead."
+                )
+                ump_dict["config"]["mxid_source_attribute"] = old_mxid_source_attribute
+
+            old_mxid_mapping = saml2_config.get("mxid_mapping")
+            if old_mxid_mapping:
+                logger.warning(
+                    "The config option saml2_config.mxid_mapping is deprecated. Please "
+                    "use saml2_config.user_mapping_provider.config.mxid_mapping instead."
+                )
+                ump_dict["config"]["mxid_mapping"] = old_mxid_mapping
+
+        # Retrieve an instance of the module's class
+        # Pass the config dictionary to the module for processing
+        (
+            self.saml2_user_mapping_provider_class,
+            self.saml2_user_mapping_provider_config,
+        ) = load_module(ump_dict)
+
+        # Ensure loaded user mapping module has defined all necessary methods
+        # Note parse_config() is already checked during the call to load_module
+        required_methods = [
+            "get_saml_attributes",
+            "saml_response_to_user_attributes",
+        ]
+        missing_methods = [
+            method
+            for method in required_methods
+            if not hasattr(self.saml2_user_mapping_provider_class, method)
+        ]
+        if missing_methods:
+            raise ConfigError(
+                "Class specified by saml2_config."
+                "user_mapping_provider.module is missing required "
+                "methods: %s" % (", ".join(missing_methods),)
+            )
+
+        # Get the desired saml auth response attributes from the module
+        saml2_config_dict = self._default_saml_config_dict(
+            *self.saml2_user_mapping_provider_class.get_saml_attributes(
+                self.saml2_user_mapping_provider_config
+            )
+        )
         _dict_merge(
             merge_dict=saml2_config.get("sp_config", {}), into_dict=saml2_config_dict
         )
@@ -103,22 +159,27 @@ class SAML2Config(Config):
             saml2_config.get("saml_session_lifetime", "5m")
         )
 
-        mapping = saml2_config.get("mxid_mapping", "hexencode")
-        try:
-            self.saml2_mxid_mapper = MXID_MAPPER_MAP[mapping]
-        except KeyError:
-            raise ConfigError("%s is not a known mxid_mapping" % (mapping,))
-
-    def _default_saml_config_dict(self):
+    def _default_saml_config_dict(
+        self, required_attributes: set, optional_attributes: set
+    ):
+        """Generate a configuration dictionary with required and optional attributes that
+        will be needed to process new user registration
+
+        Args:
+            required_attributes: SAML auth response attributes that are
+                necessary to function
+            optional_attributes: SAML auth response attributes that can be used to add
+                additional information to Synapse user accounts, but are not required
+
+        Returns:
+            dict: A SAML configuration dictionary
+        """
         import saml2
 
         public_baseurl = self.public_baseurl
         if public_baseurl is None:
             raise ConfigError("saml2_config requires a public_baseurl to be set")
 
-        required_attributes = {"uid", self.saml2_mxid_source_attribute}
-
-        optional_attributes = {"displayName"}
         if self.saml2_grandfathered_mxid_source_attribute:
             optional_attributes.add(self.saml2_grandfathered_mxid_source_attribute)
         optional_attributes -= required_attributes
@@ -207,33 +268,58 @@ class SAML2Config(Config):
           #
           #config_path: "%(config_dir_path)s/sp_conf.py"
 
-          # the lifetime of a SAML session. This defines how long a user has to
+          # The lifetime of a SAML session. This defines how long a user has to
           # complete the authentication process, if allow_unsolicited is unset.
           # The default is 5 minutes.
           #
           #saml_session_lifetime: 5m
 
-          # The SAML attribute (after mapping via the attribute maps) to use to derive
-          # the Matrix ID from. 'uid' by default.
+          # An external module can be provided here as a custom solution to
+          # mapping attributes returned from a saml provider onto a matrix user.
           #
-          #mxid_source_attribute: displayName
-
-          # The mapping system to use for mapping the saml attribute onto a matrix ID.
-          # Options include:
-          #  * 'hexencode' (which maps unpermitted characters to '=xx')
-          #  * 'dotreplace' (which replaces unpermitted characters with '.').
-          # The default is 'hexencode'.
-          #
-          #mxid_mapping: dotreplace
-
-          # In previous versions of synapse, the mapping from SAML attribute to MXID was
-          # always calculated dynamically rather than stored in a table. For backwards-
-          # compatibility, we will look for user_ids matching such a pattern before
-          # creating a new account.
+          user_mapping_provider:
+            # The custom module's class. Uncomment to use a custom module.
+            #
+            #module: mapping_provider.SamlMappingProvider
+
+            # Custom configuration values for the module. Below options are
+            # intended for the built-in provider, they should be changed if
+            # using a custom module. This section will be passed as a Python
+            # dictionary to the module's `parse_config` method.
+            #
+            config:
+              # The SAML attribute (after mapping via the attribute maps) to use
+              # to derive the Matrix ID from. 'uid' by default.
+              #
+              # Note: This used to be configured by the
+              # saml2_config.mxid_source_attribute option. If that is still
+              # defined, its value will be used instead.
+              #
+              #mxid_source_attribute: displayName
+
+              # The mapping system to use for mapping the saml attribute onto a
+              # matrix ID.
+              #
+              # Options include:
+              #  * 'hexencode' (which maps unpermitted characters to '=xx')
+              #  * 'dotreplace' (which replaces unpermitted characters with
+              #     '.').
+              # The default is 'hexencode'.
+              #
+              # Note: This used to be configured by the
+              # saml2_config.mxid_mapping option. If that is still defined, its
+              # value will be used instead.
+              #
+              #mxid_mapping: dotreplace
+
+          # In previous versions of synapse, the mapping from SAML attribute to
+          # MXID was always calculated dynamically rather than stored in a
+          # table. For backwards- compatibility, we will look for user_ids
+          # matching such a pattern before creating a new account.
           #
           # This setting controls the SAML attribute which will be used for this
-          # backwards-compatibility lookup. Typically it should be 'uid', but if the
-          # attribute maps are changed, it may be necessary to change it.
+          # backwards-compatibility lookup. Typically it should be 'uid', but if
+          # the attribute maps are changed, it may be necessary to change it.
           #
           # The default is 'uid'.
           #
@@ -241,23 +327,3 @@ class SAML2Config(Config):
         """ % {
             "config_dir_path": config_dir_path
         }
-
-
-DOT_REPLACE_PATTERN = re.compile(
-    ("[^%s]" % (re.escape("".join(mxid_localpart_allowed_characters)),))
-)
-
-
-def dot_replace_for_mxid(username: str) -> str:
-    username = username.lower()
-    username = DOT_REPLACE_PATTERN.sub(".", username)
-
-    # regular mxids aren't allowed to start with an underscore either
-    username = re.sub("^_", "", username)
-    return username
-
-
-MXID_MAPPER_MAP = {
-    "hexencode": map_username_to_mxid_localpart,
-    "dotreplace": dot_replace_for_mxid,
-}
diff --git a/synapse/event_auth.py b/synapse/event_auth.py
index ec3243b27b..c940b84470 100644
--- a/synapse/event_auth.py
+++ b/synapse/event_auth.py
@@ -42,6 +42,8 @@ def check(room_version, event, auth_events, do_sig_check=True, do_size_check=Tru
     Returns:
          if the auth checks pass.
     """
+    assert isinstance(auth_events, dict)
+
     if do_size_check:
         _check_size_limits(event)
 
@@ -74,12 +76,6 @@ def check(room_version, event, auth_events, do_sig_check=True, do_size_check=Tru
             if not event.signatures.get(event_id_domain):
                 raise AuthError(403, "Event not signed by sending server")
 
-    if auth_events is None:
-        # Oh, we don't know what the state of the room was, so we
-        # are trusting that this is allowed (at least for now)
-        logger.warning("Trusting event: %s", event.event_id)
-        return
-
     if event.type == EventTypes.Create:
         sender_domain = get_domain_from_id(event.sender)
         room_id_domain = get_domain_from_id(event.room_id)
diff --git a/synapse/federation/federation_client.py b/synapse/federation/federation_client.py
index 709449c9e3..af652a7659 100644
--- a/synapse/federation/federation_client.py
+++ b/synapse/federation/federation_client.py
@@ -18,8 +18,6 @@ import copy
 import itertools
 import logging
 
-from six.moves import range
-
 from prometheus_client import Counter
 
 from twisted.internet import defer
@@ -39,7 +37,7 @@ from synapse.api.room_versions import (
 )
 from synapse.events import builder, room_version_to_event_format
 from synapse.federation.federation_base import FederationBase, event_from_pdu_json
-from synapse.logging.context import make_deferred_yieldable, run_in_background
+from synapse.logging.context import make_deferred_yieldable
 from synapse.logging.utils import log_function
 from synapse.util import unwrapFirstError
 from synapse.util.caches.expiringcache import ExpiringCache
@@ -310,19 +308,12 @@ class FederationClient(FederationBase):
         return signed_pdu
 
     @defer.inlineCallbacks
-    @log_function
-    def get_state_for_room(self, destination, room_id, event_id):
-        """Requests all of the room state at a given event from a remote homeserver.
-
-        Args:
-            destination (str): The remote homeserver to query for the state.
-            room_id (str): The id of the room we're interested in.
-            event_id (str): The id of the event we want the state at.
+    def get_room_state_ids(self, destination: str, room_id: str, event_id: str):
+        """Calls the /state_ids endpoint to fetch the state at a particular point
+        in the room, and the auth events for the given event
 
         Returns:
-            Deferred[Tuple[List[EventBase], List[EventBase]]]:
-                A list of events in the state, and a list of events in the auth chain
-                for the given event.
+            Tuple[List[str], List[str]]:  a tuple of (state event_ids, auth event_ids)
         """
         result = yield self.transport_layer.get_room_state_ids(
             destination, room_id, event_id=event_id
@@ -331,86 +322,12 @@ class FederationClient(FederationBase):
         state_event_ids = result["pdu_ids"]
         auth_event_ids = result.get("auth_chain_ids", [])
 
-        fetched_events, failed_to_fetch = yield self.get_events_from_store_or_dest(
-            destination, room_id, set(state_event_ids + auth_event_ids)
-        )
-
-        if failed_to_fetch:
-            logger.warning(
-                "Failed to fetch missing state/auth events for %s: %s",
-                room_id,
-                failed_to_fetch,
-            )
-
-        event_map = {ev.event_id: ev for ev in fetched_events}
+        if not isinstance(state_event_ids, list) or not isinstance(
+            auth_event_ids, list
+        ):
+            raise Exception("invalid response from /state_ids")
 
-        pdus = [event_map[e_id] for e_id in state_event_ids if e_id in event_map]
-        auth_chain = [event_map[e_id] for e_id in auth_event_ids if e_id in event_map]
-
-        auth_chain.sort(key=lambda e: e.depth)
-
-        return pdus, auth_chain
-
-    @defer.inlineCallbacks
-    def get_events_from_store_or_dest(self, destination, room_id, event_ids):
-        """Fetch events from a remote destination, checking if we already have them.
-
-        Args:
-            destination (str)
-            room_id (str)
-            event_ids (list)
-
-        Returns:
-            Deferred: A deferred resolving to a 2-tuple where the first is a list of
-            events and the second is a list of event ids that we failed to fetch.
-        """
-        seen_events = yield self.store.get_events(event_ids, allow_rejected=True)
-        signed_events = list(seen_events.values())
-
-        failed_to_fetch = set()
-
-        missing_events = set(event_ids)
-        for k in seen_events:
-            missing_events.discard(k)
-
-        if not missing_events:
-            return signed_events, failed_to_fetch
-
-        logger.debug(
-            "Fetching unknown state/auth events %s for room %s",
-            missing_events,
-            event_ids,
-        )
-
-        room_version = yield self.store.get_room_version(room_id)
-
-        batch_size = 20
-        missing_events = list(missing_events)
-        for i in range(0, len(missing_events), batch_size):
-            batch = set(missing_events[i : i + batch_size])
-
-            deferreds = [
-                run_in_background(
-                    self.get_pdu,
-                    destinations=[destination],
-                    event_id=e_id,
-                    room_version=room_version,
-                )
-                for e_id in batch
-            ]
-
-            res = yield make_deferred_yieldable(
-                defer.DeferredList(deferreds, consumeErrors=True)
-            )
-            for success, result in res:
-                if success and result:
-                    signed_events.append(result)
-                    batch.discard(result.event_id)
-
-            # We removed all events we successfully fetched from `batch`
-            failed_to_fetch.update(batch)
-
-        return signed_events, failed_to_fetch
+        return state_event_ids, auth_event_ids
 
     @defer.inlineCallbacks
     @log_function
@@ -609,13 +526,7 @@ class FederationClient(FederationBase):
 
         @defer.inlineCallbacks
         def send_request(destination):
-            time_now = self._clock.time_msec()
-            _, content = yield self.transport_layer.send_join(
-                destination=destination,
-                room_id=pdu.room_id,
-                event_id=pdu.event_id,
-                content=pdu.get_pdu_json(time_now),
-            )
+            content = yield self._do_send_join(destination, pdu)
 
             logger.debug("Got content: %s", content)
 
@@ -683,6 +594,44 @@ class FederationClient(FederationBase):
         return self._try_destination_list("send_join", destinations, send_request)
 
     @defer.inlineCallbacks
+    def _do_send_join(self, destination, pdu):
+        time_now = self._clock.time_msec()
+
+        try:
+            content = yield self.transport_layer.send_join_v2(
+                destination=destination,
+                room_id=pdu.room_id,
+                event_id=pdu.event_id,
+                content=pdu.get_pdu_json(time_now),
+            )
+
+            return content
+        except HttpResponseException as e:
+            if e.code in [400, 404]:
+                err = e.to_synapse_error()
+
+                # If we receive an error response that isn't a generic error, or an
+                # unrecognised endpoint error, we  assume that the remote understands
+                # the v2 invite API and this is a legitimate error.
+                if err.errcode not in [Codes.UNKNOWN, Codes.UNRECOGNIZED]:
+                    raise err
+            else:
+                raise e.to_synapse_error()
+
+        logger.debug("Couldn't send_join with the v2 API, falling back to the v1 API")
+
+        resp = yield self.transport_layer.send_join_v1(
+            destination=destination,
+            room_id=pdu.room_id,
+            event_id=pdu.event_id,
+            content=pdu.get_pdu_json(time_now),
+        )
+
+        # We expect the v1 API to respond with [200, content], so we only return the
+        # content.
+        return resp[1]
+
+    @defer.inlineCallbacks
     def send_invite(self, destination, room_id, event_id, pdu):
         room_version = yield self.store.get_room_version(room_id)
 
@@ -791,18 +740,50 @@ class FederationClient(FederationBase):
 
         @defer.inlineCallbacks
         def send_request(destination):
-            time_now = self._clock.time_msec()
-            _, content = yield self.transport_layer.send_leave(
+            content = yield self._do_send_leave(destination, pdu)
+
+            logger.debug("Got content: %s", content)
+            return None
+
+        return self._try_destination_list("send_leave", destinations, send_request)
+
+    @defer.inlineCallbacks
+    def _do_send_leave(self, destination, pdu):
+        time_now = self._clock.time_msec()
+
+        try:
+            content = yield self.transport_layer.send_leave_v2(
                 destination=destination,
                 room_id=pdu.room_id,
                 event_id=pdu.event_id,
                 content=pdu.get_pdu_json(time_now),
             )
 
-            logger.debug("Got content: %s", content)
-            return None
+            return content
+        except HttpResponseException as e:
+            if e.code in [400, 404]:
+                err = e.to_synapse_error()
 
-        return self._try_destination_list("send_leave", destinations, send_request)
+                # If we receive an error response that isn't a generic error, or an
+                # unrecognised endpoint error, we  assume that the remote understands
+                # the v2 invite API and this is a legitimate error.
+                if err.errcode not in [Codes.UNKNOWN, Codes.UNRECOGNIZED]:
+                    raise err
+            else:
+                raise e.to_synapse_error()
+
+        logger.debug("Couldn't send_leave with the v2 API, falling back to the v1 API")
+
+        resp = yield self.transport_layer.send_leave_v1(
+            destination=destination,
+            room_id=pdu.room_id,
+            event_id=pdu.event_id,
+            content=pdu.get_pdu_json(time_now),
+        )
+
+        # We expect the v1 API to respond with [200, content], so we only return the
+        # content.
+        return resp[1]
 
     def get_public_rooms(
         self,
diff --git a/synapse/federation/federation_server.py b/synapse/federation/federation_server.py
index 84d4eca041..d7ce333822 100644
--- a/synapse/federation/federation_server.py
+++ b/synapse/federation/federation_server.py
@@ -384,15 +384,10 @@ class FederationServer(FederationBase):
 
         res_pdus = await self.handler.on_send_join_request(origin, pdu)
         time_now = self._clock.time_msec()
-        return (
-            200,
-            {
-                "state": [p.get_pdu_json(time_now) for p in res_pdus["state"]],
-                "auth_chain": [
-                    p.get_pdu_json(time_now) for p in res_pdus["auth_chain"]
-                ],
-            },
-        )
+        return {
+            "state": [p.get_pdu_json(time_now) for p in res_pdus["state"]],
+            "auth_chain": [p.get_pdu_json(time_now) for p in res_pdus["auth_chain"]],
+        }
 
     async def on_make_leave_request(self, origin, room_id, user_id):
         origin_host, _ = parse_server_name(origin)
@@ -419,7 +414,7 @@ class FederationServer(FederationBase):
         pdu = await self._check_sigs_and_hash(room_version, pdu)
 
         await self.handler.on_send_leave_request(origin, pdu)
-        return 200, {}
+        return {}
 
     async def on_event_auth(self, origin, room_id, event_id):
         with (await self._server_linearizer.queue((origin, room_id))):
diff --git a/synapse/federation/transport/client.py b/synapse/federation/transport/client.py
index 46dba84cac..198257414b 100644
--- a/synapse/federation/transport/client.py
+++ b/synapse/federation/transport/client.py
@@ -243,7 +243,7 @@ class TransportLayerClient(object):
 
     @defer.inlineCallbacks
     @log_function
-    def send_join(self, destination, room_id, event_id, content):
+    def send_join_v1(self, destination, room_id, event_id, content):
         path = _create_v1_path("/send_join/%s/%s", room_id, event_id)
 
         response = yield self.client.put_json(
@@ -254,7 +254,18 @@ class TransportLayerClient(object):
 
     @defer.inlineCallbacks
     @log_function
-    def send_leave(self, destination, room_id, event_id, content):
+    def send_join_v2(self, destination, room_id, event_id, content):
+        path = _create_v2_path("/send_join/%s/%s", room_id, event_id)
+
+        response = yield self.client.put_json(
+            destination=destination, path=path, data=content
+        )
+
+        return response
+
+    @defer.inlineCallbacks
+    @log_function
+    def send_leave_v1(self, destination, room_id, event_id, content):
         path = _create_v1_path("/send_leave/%s/%s", room_id, event_id)
 
         response = yield self.client.put_json(
@@ -272,6 +283,24 @@ class TransportLayerClient(object):
 
     @defer.inlineCallbacks
     @log_function
+    def send_leave_v2(self, destination, room_id, event_id, content):
+        path = _create_v2_path("/send_leave/%s/%s", room_id, event_id)
+
+        response = yield self.client.put_json(
+            destination=destination,
+            path=path,
+            data=content,
+            # we want to do our best to send this through. The problem is
+            # that if it fails, we won't retry it later, so if the remote
+            # server was just having a momentary blip, the room will be out of
+            # sync.
+            ignore_backoff=True,
+        )
+
+        return response
+
+    @defer.inlineCallbacks
+    @log_function
     def send_invite_v1(self, destination, room_id, event_id, content):
         path = _create_v1_path("/invite/%s/%s", room_id, event_id)
 
diff --git a/synapse/federation/transport/server.py b/synapse/federation/transport/server.py
index fefc789c85..b4cbf23394 100644
--- a/synapse/federation/transport/server.py
+++ b/synapse/federation/transport/server.py
@@ -506,11 +506,21 @@ class FederationMakeLeaveServlet(BaseFederationServlet):
         return 200, content
 
 
-class FederationSendLeaveServlet(BaseFederationServlet):
+class FederationV1SendLeaveServlet(BaseFederationServlet):
     PATH = "/send_leave/(?P<room_id>[^/]*)/(?P<event_id>[^/]*)"
 
     async def on_PUT(self, origin, content, query, room_id, event_id):
         content = await self.handler.on_send_leave_request(origin, content, room_id)
+        return 200, (200, content)
+
+
+class FederationV2SendLeaveServlet(BaseFederationServlet):
+    PATH = "/send_leave/(?P<room_id>[^/]*)/(?P<event_id>[^/]*)"
+
+    PREFIX = FEDERATION_V2_PREFIX
+
+    async def on_PUT(self, origin, content, query, room_id, event_id):
+        content = await self.handler.on_send_leave_request(origin, content, room_id)
         return 200, content
 
 
@@ -521,9 +531,21 @@ class FederationEventAuthServlet(BaseFederationServlet):
         return await self.handler.on_event_auth(origin, context, event_id)
 
 
-class FederationSendJoinServlet(BaseFederationServlet):
+class FederationV1SendJoinServlet(BaseFederationServlet):
+    PATH = "/send_join/(?P<context>[^/]*)/(?P<event_id>[^/]*)"
+
+    async def on_PUT(self, origin, content, query, context, event_id):
+        # TODO(paul): assert that context/event_id parsed from path actually
+        #   match those given in content
+        content = await self.handler.on_send_join_request(origin, content, context)
+        return 200, (200, content)
+
+
+class FederationV2SendJoinServlet(BaseFederationServlet):
     PATH = "/send_join/(?P<context>[^/]*)/(?P<event_id>[^/]*)"
 
+    PREFIX = FEDERATION_V2_PREFIX
+
     async def on_PUT(self, origin, content, query, context, event_id):
         # TODO(paul): assert that context/event_id parsed from path actually
         #   match those given in content
@@ -1367,8 +1389,10 @@ FEDERATION_SERVLET_CLASSES = (
     FederationMakeJoinServlet,
     FederationMakeLeaveServlet,
     FederationEventServlet,
-    FederationSendJoinServlet,
-    FederationSendLeaveServlet,
+    FederationV1SendJoinServlet,
+    FederationV2SendJoinServlet,
+    FederationV1SendLeaveServlet,
+    FederationV2SendLeaveServlet,
     FederationV1InviteServlet,
     FederationV2InviteServlet,
     FederationQueryAuthServlet,
diff --git a/synapse/handlers/account_data.py b/synapse/handlers/account_data.py
index 2d7e6df6e4..20ec1ca01b 100644
--- a/synapse/handlers/account_data.py
+++ b/synapse/handlers/account_data.py
@@ -13,8 +13,6 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
-from twisted.internet import defer
-
 
 class AccountDataEventSource(object):
     def __init__(self, hs):
@@ -23,15 +21,14 @@ class AccountDataEventSource(object):
     def get_current_key(self, direction="f"):
         return self.store.get_max_account_data_stream_id()
 
-    @defer.inlineCallbacks
-    def get_new_events(self, user, from_key, **kwargs):
+    async def get_new_events(self, user, from_key, **kwargs):
         user_id = user.to_string()
         last_stream_id = from_key
 
-        current_stream_id = yield self.store.get_max_account_data_stream_id()
+        current_stream_id = self.store.get_max_account_data_stream_id()
 
         results = []
-        tags = yield self.store.get_updated_tags(user_id, last_stream_id)
+        tags = await self.store.get_updated_tags(user_id, last_stream_id)
 
         for room_id, room_tags in tags.items():
             results.append(
@@ -41,7 +38,7 @@ class AccountDataEventSource(object):
         (
             account_data,
             room_account_data,
-        ) = yield self.store.get_updated_account_data_for_user(user_id, last_stream_id)
+        ) = await self.store.get_updated_account_data_for_user(user_id, last_stream_id)
 
         for account_data_type, content in account_data.items():
             results.append({"type": account_data_type, "content": content})
@@ -54,6 +51,5 @@ class AccountDataEventSource(object):
 
         return results, current_stream_id
 
-    @defer.inlineCallbacks
-    def get_pagination_rows(self, user, config, key):
+    async def get_pagination_rows(self, user, config, key):
         return [], config.to_id
diff --git a/synapse/handlers/account_validity.py b/synapse/handlers/account_validity.py
index d04e0fe576..829f52eca1 100644
--- a/synapse/handlers/account_validity.py
+++ b/synapse/handlers/account_validity.py
@@ -18,8 +18,7 @@ import email.utils
 import logging
 from email.mime.multipart import MIMEMultipart
 from email.mime.text import MIMEText
-
-from twisted.internet import defer
+from typing import List
 
 from synapse.api.errors import StoreError
 from synapse.logging.context import make_deferred_yieldable
@@ -78,42 +77,39 @@ class AccountValidityHandler(object):
                 # run as a background process to make sure that the database transactions
                 # have a logcontext to report to
                 return run_as_background_process(
-                    "send_renewals", self.send_renewal_emails
+                    "send_renewals", self._send_renewal_emails
                 )
 
             self.clock.looping_call(send_emails, 30 * 60 * 1000)
 
-    @defer.inlineCallbacks
-    def send_renewal_emails(self):
+    async def _send_renewal_emails(self):
         """Gets the list of users whose account is expiring in the amount of time
         configured in the ``renew_at`` parameter from the ``account_validity``
         configuration, and sends renewal emails to all of these users as long as they
         have an email 3PID attached to their account.
         """
-        expiring_users = yield self.store.get_users_expiring_soon()
+        expiring_users = await self.store.get_users_expiring_soon()
 
         if expiring_users:
             for user in expiring_users:
-                yield self._send_renewal_email(
+                await self._send_renewal_email(
                     user_id=user["user_id"], expiration_ts=user["expiration_ts_ms"]
                 )
 
-    @defer.inlineCallbacks
-    def send_renewal_email_to_user(self, user_id):
-        expiration_ts = yield self.store.get_expiration_ts_for_user(user_id)
-        yield self._send_renewal_email(user_id, expiration_ts)
+    async def send_renewal_email_to_user(self, user_id: str):
+        expiration_ts = await self.store.get_expiration_ts_for_user(user_id)
+        await self._send_renewal_email(user_id, expiration_ts)
 
-    @defer.inlineCallbacks
-    def _send_renewal_email(self, user_id, expiration_ts):
+    async def _send_renewal_email(self, user_id: str, expiration_ts: int):
         """Sends out a renewal email to every email address attached to the given user
         with a unique link allowing them to renew their account.
 
         Args:
-            user_id (str): ID of the user to send email(s) to.
-            expiration_ts (int): Timestamp in milliseconds for the expiration date of
+            user_id: ID of the user to send email(s) to.
+            expiration_ts: Timestamp in milliseconds for the expiration date of
                 this user's account (used in the email templates).
         """
-        addresses = yield self._get_email_addresses_for_user(user_id)
+        addresses = await self._get_email_addresses_for_user(user_id)
 
         # Stop right here if the user doesn't have at least one email address.
         # In this case, they will have to ask their server admin to renew their
@@ -125,7 +121,7 @@ class AccountValidityHandler(object):
             return
 
         try:
-            user_display_name = yield self.store.get_profile_displayname(
+            user_display_name = await self.store.get_profile_displayname(
                 UserID.from_string(user_id).localpart
             )
             if user_display_name is None:
@@ -133,7 +129,7 @@ class AccountValidityHandler(object):
         except StoreError:
             user_display_name = user_id
 
-        renewal_token = yield self._get_renewal_token(user_id)
+        renewal_token = await self._get_renewal_token(user_id)
         url = "%s_matrix/client/unstable/account_validity/renew?token=%s" % (
             self.hs.config.public_baseurl,
             renewal_token,
@@ -165,7 +161,7 @@ class AccountValidityHandler(object):
 
             logger.info("Sending renewal email to %s", address)
 
-            yield make_deferred_yieldable(
+            await make_deferred_yieldable(
                 self.sendmail(
                     self.hs.config.email_smtp_host,
                     self._raw_from,
@@ -180,19 +176,18 @@ class AccountValidityHandler(object):
                 )
             )
 
-        yield self.store.set_renewal_mail_status(user_id=user_id, email_sent=True)
+        await self.store.set_renewal_mail_status(user_id=user_id, email_sent=True)
 
-    @defer.inlineCallbacks
-    def _get_email_addresses_for_user(self, user_id):
+    async def _get_email_addresses_for_user(self, user_id: str) -> List[str]:
         """Retrieve the list of email addresses attached to a user's account.
 
         Args:
-            user_id (str): ID of the user to lookup email addresses for.
+            user_id: ID of the user to lookup email addresses for.
 
         Returns:
-            defer.Deferred[list[str]]: Email addresses for this account.
+            Email addresses for this account.
         """
-        threepids = yield self.store.user_get_threepids(user_id)
+        threepids = await self.store.user_get_threepids(user_id)
 
         addresses = []
         for threepid in threepids:
@@ -201,16 +196,15 @@ class AccountValidityHandler(object):
 
         return addresses
 
-    @defer.inlineCallbacks
-    def _get_renewal_token(self, user_id):
+    async def _get_renewal_token(self, user_id: str) -> str:
         """Generates a 32-byte long random string that will be inserted into the
         user's renewal email's unique link, then saves it into the database.
 
         Args:
-            user_id (str): ID of the user to generate a string for.
+            user_id: ID of the user to generate a string for.
 
         Returns:
-            defer.Deferred[str]: The generated string.
+            The generated string.
 
         Raises:
             StoreError(500): Couldn't generate a unique string after 5 attempts.
@@ -219,52 +213,52 @@ class AccountValidityHandler(object):
         while attempts < 5:
             try:
                 renewal_token = stringutils.random_string(32)
-                yield self.store.set_renewal_token_for_user(user_id, renewal_token)
+                await self.store.set_renewal_token_for_user(user_id, renewal_token)
                 return renewal_token
             except StoreError:
                 attempts += 1
         raise StoreError(500, "Couldn't generate a unique string as refresh string.")
 
-    @defer.inlineCallbacks
-    def renew_account(self, renewal_token):
+    async def renew_account(self, renewal_token: str) -> bool:
         """Renews the account attached to a given renewal token by pushing back the
         expiration date by the current validity period in the server's configuration.
 
         Args:
-            renewal_token (str): Token sent with the renewal request.
+            renewal_token: Token sent with the renewal request.
         Returns:
-            bool: Whether the provided token is valid.
+            Whether the provided token is valid.
         """
         try:
-            user_id = yield self.store.get_user_from_renewal_token(renewal_token)
+            user_id = await self.store.get_user_from_renewal_token(renewal_token)
         except StoreError:
-            defer.returnValue(False)
+            return False
 
         logger.debug("Renewing an account for user %s", user_id)
-        yield self.renew_account_for_user(user_id)
+        await self.renew_account_for_user(user_id)
 
-        defer.returnValue(True)
+        return True
 
-    @defer.inlineCallbacks
-    def renew_account_for_user(self, user_id, expiration_ts=None, email_sent=False):
+    async def renew_account_for_user(
+        self, user_id: str, expiration_ts: int = None, email_sent: bool = False
+    ) -> int:
         """Renews the account attached to a given user by pushing back the
         expiration date by the current validity period in the server's
         configuration.
 
         Args:
-            renewal_token (str): Token sent with the renewal request.
-            expiration_ts (int): New expiration date. Defaults to now + validity period.
-            email_sent (bool): Whether an email has been sent for this validity period.
+            renewal_token: Token sent with the renewal request.
+            expiration_ts: New expiration date. Defaults to now + validity period.
+            email_sen: Whether an email has been sent for this validity period.
                 Defaults to False.
 
         Returns:
-            defer.Deferred[int]: New expiration date for this account, as a timestamp
-                in milliseconds since epoch.
+            New expiration date for this account, as a timestamp in
+            milliseconds since epoch.
         """
         if expiration_ts is None:
             expiration_ts = self.clock.time_msec() + self._account_validity.period
 
-        yield self.store.set_account_validity_for_user(
+        await self.store.set_account_validity_for_user(
             user_id=user_id, expiration_ts=expiration_ts, email_sent=email_sent
         )
 
diff --git a/synapse/handlers/e2e_keys.py b/synapse/handlers/e2e_keys.py
index 28c12753c1..57a10daefd 100644
--- a/synapse/handlers/e2e_keys.py
+++ b/synapse/handlers/e2e_keys.py
@@ -264,7 +264,6 @@ class E2eKeysHandler(object):
 
         return ret
 
-    @defer.inlineCallbacks
     def get_cross_signing_keys_from_cache(self, query, from_user_id):
         """Get cross-signing keys for users from the database
 
@@ -284,35 +283,14 @@ class E2eKeysHandler(object):
         self_signing_keys = {}
         user_signing_keys = {}
 
-        for user_id in query:
-            # XXX: consider changing the store functions to allow querying
-            # multiple users simultaneously.
-            key = yield self.store.get_e2e_cross_signing_key(
-                user_id, "master", from_user_id
-            )
-            if key:
-                master_keys[user_id] = key
-
-            key = yield self.store.get_e2e_cross_signing_key(
-                user_id, "self_signing", from_user_id
-            )
-            if key:
-                self_signing_keys[user_id] = key
-
-            # users can see other users' master and self-signing keys, but can
-            # only see their own user-signing keys
-            if from_user_id == user_id:
-                key = yield self.store.get_e2e_cross_signing_key(
-                    user_id, "user_signing", from_user_id
-                )
-                if key:
-                    user_signing_keys[user_id] = key
-
-        return {
-            "master_keys": master_keys,
-            "self_signing_keys": self_signing_keys,
-            "user_signing_keys": user_signing_keys,
-        }
+        # Currently a stub, implementation coming in https://github.com/matrix-org/synapse/pull/6486
+        return defer.succeed(
+            {
+                "master_keys": master_keys,
+                "self_signing_keys": self_signing_keys,
+                "user_signing_keys": user_signing_keys,
+            }
+        )
 
     @trace
     @defer.inlineCallbacks
diff --git a/synapse/handlers/federation.py b/synapse/handlers/federation.py
index bc26921768..62985bab9f 100644
--- a/synapse/handlers/federation.py
+++ b/synapse/handlers/federation.py
@@ -19,7 +19,7 @@
 
 import itertools
 import logging
-from typing import Dict, Iterable, Optional, Sequence, Tuple
+from typing import Dict, Iterable, List, Optional, Sequence, Tuple
 
 import six
 from six import iteritems, itervalues
@@ -63,8 +63,9 @@ from synapse.replication.http.federation import (
 )
 from synapse.replication.http.membership import ReplicationUserJoinedLeftRoomRestServlet
 from synapse.state import StateResolutionStore, resolve_events_with_store
+from synapse.storage.data_stores.main.events_worker import EventRedactBehaviour
 from synapse.types import UserID, get_domain_from_id
-from synapse.util import unwrapFirstError
+from synapse.util import batch_iter, unwrapFirstError
 from synapse.util.async_helpers import Linearizer
 from synapse.util.distributor import user_joined_room
 from synapse.util.retryutils import NotRetryingDestination
@@ -164,8 +165,7 @@ class FederationHandler(BaseHandler):
 
         self._ephemeral_messages_enabled = hs.config.enable_ephemeral_messages
 
-    @defer.inlineCallbacks
-    def on_receive_pdu(self, origin, pdu, sent_to_us_directly=False):
+    async def on_receive_pdu(self, origin, pdu, sent_to_us_directly=False) -> None:
         """ Process a PDU received via a federation /send/ transaction, or
         via backfill of missing prev_events
 
@@ -175,17 +175,15 @@ class FederationHandler(BaseHandler):
             pdu (FrozenEvent): received PDU
             sent_to_us_directly (bool): True if this event was pushed to us; False if
                 we pulled it as the result of a missing prev_event.
-
-        Returns (Deferred): completes with None
         """
 
         room_id = pdu.room_id
         event_id = pdu.event_id
 
-        logger.info("[%s %s] handling received PDU: %s", room_id, event_id, pdu)
+        logger.info("handling received PDU: %s", pdu)
 
         # We reprocess pdus when we have seen them only as outliers
-        existing = yield self.store.get_event(
+        existing = await self.store.get_event(
             event_id, allow_none=True, allow_rejected=True
         )
 
@@ -229,7 +227,7 @@ class FederationHandler(BaseHandler):
         #
         # Note that if we were never in the room then we would have already
         # dropped the event, since we wouldn't know the room version.
-        is_in_room = yield self.auth.check_host_in_room(room_id, self.server_name)
+        is_in_room = await self.auth.check_host_in_room(room_id, self.server_name)
         if not is_in_room:
             logger.info(
                 "[%s %s] Ignoring PDU from %s as we're not in the room",
@@ -245,12 +243,12 @@ class FederationHandler(BaseHandler):
         # Get missing pdus if necessary.
         if not pdu.internal_metadata.is_outlier():
             # We only backfill backwards to the min depth.
-            min_depth = yield self.get_min_depth_for_context(pdu.room_id)
+            min_depth = await self.get_min_depth_for_context(pdu.room_id)
 
             logger.debug("[%s %s] min_depth: %d", room_id, event_id, min_depth)
 
             prevs = set(pdu.prev_event_ids())
-            seen = yield self.store.have_seen_events(prevs)
+            seen = await self.store.have_seen_events(prevs)
 
             if min_depth and pdu.depth < min_depth:
                 # This is so that we don't notify the user about this
@@ -270,7 +268,7 @@ class FederationHandler(BaseHandler):
                         len(missing_prevs),
                         shortstr(missing_prevs),
                     )
-                    with (yield self._room_pdu_linearizer.queue(pdu.room_id)):
+                    with (await self._room_pdu_linearizer.queue(pdu.room_id)):
                         logger.info(
                             "[%s %s] Acquired room lock to fetch %d missing prev_events",
                             room_id,
@@ -278,13 +276,19 @@ class FederationHandler(BaseHandler):
                             len(missing_prevs),
                         )
 
-                        yield self._get_missing_events_for_pdu(
-                            origin, pdu, prevs, min_depth
-                        )
+                        try:
+                            await self._get_missing_events_for_pdu(
+                                origin, pdu, prevs, min_depth
+                            )
+                        except Exception as e:
+                            raise Exception(
+                                "Error fetching missing prev_events for %s: %s"
+                                % (event_id, e)
+                            )
 
                         # Update the set of things we've seen after trying to
                         # fetch the missing stuff
-                        seen = yield self.store.have_seen_events(prevs)
+                        seen = await self.store.have_seen_events(prevs)
 
                         if not prevs - seen:
                             logger.info(
@@ -292,14 +296,6 @@ class FederationHandler(BaseHandler):
                                 room_id,
                                 event_id,
                             )
-                elif missing_prevs:
-                    logger.info(
-                        "[%s %s] Not recursively fetching %d missing prev_events: %s",
-                        room_id,
-                        event_id,
-                        len(missing_prevs),
-                        shortstr(missing_prevs),
-                    )
 
             if prevs - seen:
                 # We've still not been able to get all of the prev_events for this event.
@@ -344,13 +340,19 @@ class FederationHandler(BaseHandler):
                         affected=pdu.event_id,
                     )
 
+                logger.info(
+                    "Event %s is missing prev_events: calculating state for a "
+                    "backwards extremity",
+                    event_id,
+                )
+
                 # Calculate the state after each of the previous events, and
                 # resolve them to find the correct state at the current event.
                 auth_chains = set()
                 event_map = {event_id: pdu}
                 try:
                     # Get the state of the events we know about
-                    ours = yield self.state_store.get_state_groups_ids(room_id, seen)
+                    ours = await self.state_store.get_state_groups_ids(room_id, seen)
 
                     # state_maps is a list of mappings from (type, state_key) to event_id
                     state_maps = list(
@@ -364,13 +366,10 @@ class FederationHandler(BaseHandler):
                     # know about
                     for p in prevs - seen:
                         logger.info(
-                            "[%s %s] Requesting state at missing prev_event %s",
-                            room_id,
-                            event_id,
-                            p,
+                            "Requesting state at missing prev_event %s", event_id,
                         )
 
-                        room_version = yield self.store.get_room_version(room_id)
+                        room_version = await self.store.get_room_version(room_id)
 
                         with nested_logging_context(p):
                             # note that if any of the missing prevs share missing state or
@@ -379,24 +378,10 @@ class FederationHandler(BaseHandler):
                             (
                                 remote_state,
                                 got_auth_chain,
-                            ) = yield self.federation_client.get_state_for_room(
-                                origin, room_id, p
-                            )
-
-                            # we want the state *after* p; get_state_for_room returns the
-                            # state *before* p.
-                            remote_event = yield self.federation_client.get_pdu(
-                                [origin], p, room_version, outlier=True
+                            ) = await self._get_state_for_room(
+                                origin, room_id, p, include_event_in_state=True
                             )
 
-                            if remote_event is None:
-                                raise Exception(
-                                    "Unable to get missing prev_event %s" % (p,)
-                                )
-
-                            if remote_event.is_state():
-                                remote_state.append(remote_event)
-
                             # XXX hrm I'm not convinced that duplicate events will compare
                             # for equality, so I'm not sure this does what the author
                             # hoped.
@@ -410,7 +395,7 @@ class FederationHandler(BaseHandler):
                             for x in remote_state:
                                 event_map[x.event_id] = x
 
-                    state_map = yield resolve_events_with_store(
+                    state_map = await resolve_events_with_store(
                         room_version,
                         state_maps,
                         event_map,
@@ -422,10 +407,10 @@ class FederationHandler(BaseHandler):
 
                     # First though we need to fetch all the events that are in
                     # state_map, so we can build up the state below.
-                    evs = yield self.store.get_events(
+                    evs = await self.store.get_events(
                         list(state_map.values()),
                         get_prev_content=False,
-                        check_redacted=False,
+                        redact_behaviour=EventRedactBehaviour.AS_IS,
                     )
                     event_map.update(evs)
 
@@ -446,12 +431,11 @@ class FederationHandler(BaseHandler):
                         affected=event_id,
                     )
 
-        yield self._process_received_pdu(
+        await self._process_received_pdu(
             origin, pdu, state=state, auth_chain=auth_chain
         )
 
-    @defer.inlineCallbacks
-    def _get_missing_events_for_pdu(self, origin, pdu, prevs, min_depth):
+    async def _get_missing_events_for_pdu(self, origin, pdu, prevs, min_depth):
         """
         Args:
             origin (str): Origin of the pdu. Will be called to get the missing events
@@ -463,12 +447,12 @@ class FederationHandler(BaseHandler):
         room_id = pdu.room_id
         event_id = pdu.event_id
 
-        seen = yield self.store.have_seen_events(prevs)
+        seen = await self.store.have_seen_events(prevs)
 
         if not prevs - seen:
             return
 
-        latest = yield self.store.get_latest_event_ids_in_room(room_id)
+        latest = await self.store.get_latest_event_ids_in_room(room_id)
 
         # We add the prev events that we have seen to the latest
         # list to ensure the remote server doesn't give them to us
@@ -532,7 +516,7 @@ class FederationHandler(BaseHandler):
         # All that said: Let's try increasing the timout to 60s and see what happens.
 
         try:
-            missing_events = yield self.federation_client.get_missing_events(
+            missing_events = await self.federation_client.get_missing_events(
                 origin,
                 room_id,
                 earliest_events_ids=list(latest),
@@ -571,7 +555,7 @@ class FederationHandler(BaseHandler):
             )
             with nested_logging_context(ev.event_id):
                 try:
-                    yield self.on_receive_pdu(origin, ev, sent_to_us_directly=False)
+                    await self.on_receive_pdu(origin, ev, sent_to_us_directly=False)
                 except FederationError as e:
                     if e.code == 403:
                         logger.warning(
@@ -583,8 +567,116 @@ class FederationHandler(BaseHandler):
                     else:
                         raise
 
-    @defer.inlineCallbacks
-    def _process_received_pdu(self, origin, event, state, auth_chain):
+    async def _get_state_for_room(
+        self,
+        destination: str,
+        room_id: str,
+        event_id: str,
+        include_event_in_state: bool = False,
+    ) -> Tuple[List[EventBase], List[EventBase]]:
+        """Requests all of the room state at a given event from a remote homeserver.
+
+        Args:
+            destination: The remote homeserver to query for the state.
+            room_id: The id of the room we're interested in.
+            event_id: The id of the event we want the state at.
+            include_event_in_state: if true, the event itself will be included in the
+                returned state event list.
+
+        Returns:
+            A list of events in the state, possibly including the event itself, and
+            a list of events in the auth chain for the given event.
+        """
+        (
+            state_event_ids,
+            auth_event_ids,
+        ) = await self.federation_client.get_room_state_ids(
+            destination, room_id, event_id=event_id
+        )
+
+        desired_events = set(state_event_ids + auth_event_ids)
+
+        if include_event_in_state:
+            desired_events.add(event_id)
+
+        event_map = await self._get_events_from_store_or_dest(
+            destination, room_id, desired_events
+        )
+
+        failed_to_fetch = desired_events - event_map.keys()
+        if failed_to_fetch:
+            logger.warning(
+                "Failed to fetch missing state/auth events for %s %s",
+                event_id,
+                failed_to_fetch,
+            )
+
+        remote_state = [
+            event_map[e_id] for e_id in state_event_ids if e_id in event_map
+        ]
+
+        if include_event_in_state:
+            remote_event = event_map.get(event_id)
+            if not remote_event:
+                raise Exception("Unable to get missing prev_event %s" % (event_id,))
+            if remote_event.is_state():
+                remote_state.append(remote_event)
+
+        auth_chain = [event_map[e_id] for e_id in auth_event_ids if e_id in event_map]
+        auth_chain.sort(key=lambda e: e.depth)
+
+        return remote_state, auth_chain
+
+    async def _get_events_from_store_or_dest(
+        self, destination: str, room_id: str, event_ids: Iterable[str]
+    ) -> Dict[str, EventBase]:
+        """Fetch events from a remote destination, checking if we already have them.
+
+        Args:
+            destination
+            room_id
+            event_ids
+
+        Returns:
+            map from event_id to event
+        """
+        fetched_events = await self.store.get_events(event_ids, allow_rejected=True)
+
+        missing_events = set(event_ids) - fetched_events.keys()
+
+        if not missing_events:
+            return fetched_events
+
+        logger.debug(
+            "Fetching unknown state/auth events %s for room %s",
+            missing_events,
+            event_ids,
+        )
+
+        room_version = await self.store.get_room_version(room_id)
+
+        # XXX 20 requests at once? really?
+        for batch in batch_iter(missing_events, 20):
+            deferreds = [
+                run_in_background(
+                    self.federation_client.get_pdu,
+                    destinations=[destination],
+                    event_id=e_id,
+                    room_version=room_version,
+                )
+                for e_id in batch
+            ]
+
+            res = await make_deferred_yieldable(
+                defer.DeferredList(deferreds, consumeErrors=True)
+            )
+            for success, result in res:
+                if success and result:
+                    fetched_events[result.event_id] = result
+
+        return fetched_events
+
+    async def _process_received_pdu(self, origin, event, state, auth_chain):
         """ Called when we have a new pdu. We need to do auth checks and put it
         through the StateHandler.
         """
@@ -599,7 +691,7 @@ class FederationHandler(BaseHandler):
         if auth_chain:
             event_ids |= {e.event_id for e in auth_chain}
 
-        seen_ids = yield self.store.have_seen_events(event_ids)
+        seen_ids = await self.store.have_seen_events(event_ids)
 
         if state and auth_chain is not None:
             # If we have any state or auth_chain given to us by the replication
@@ -626,18 +718,18 @@ class FederationHandler(BaseHandler):
                 event_id,
                 [e.event.event_id for e in event_infos],
             )
-            yield self._handle_new_events(origin, event_infos)
+            await self._handle_new_events(origin, event_infos)
 
         try:
-            context = yield self._handle_new_event(origin, event, state=state)
+            context = await self._handle_new_event(origin, event, state=state)
         except AuthError as e:
             raise FederationError("ERROR", e.code, e.msg, affected=event.event_id)
 
-        room = yield self.store.get_room(room_id)
+        room = await self.store.get_room(room_id)
 
         if not room:
             try:
-                yield self.store.store_room(
+                await self.store.store_room(
                     room_id=room_id, room_creator_user_id="", is_public=False
                 )
             except StoreError:
@@ -650,11 +742,11 @@ class FederationHandler(BaseHandler):
                 # changing their profile info.
                 newly_joined = True
 
-                prev_state_ids = yield context.get_prev_state_ids(self.store)
+                prev_state_ids = await context.get_prev_state_ids(self.store)
 
                 prev_state_id = prev_state_ids.get((event.type, event.state_key))
                 if prev_state_id:
-                    prev_state = yield self.store.get_event(
+                    prev_state = await self.store.get_event(
                         prev_state_id, allow_none=True
                     )
                     if prev_state and prev_state.membership == Membership.JOIN:
@@ -662,11 +754,10 @@ class FederationHandler(BaseHandler):
 
                 if newly_joined:
                     user = UserID.from_string(event.state_key)
-                    yield self.user_joined_room(user, room_id)
+                    await self.user_joined_room(user, room_id)
 
     @log_function
-    @defer.inlineCallbacks
-    def backfill(self, dest, room_id, limit, extremities):
+    async def backfill(self, dest, room_id, limit, extremities):
         """ Trigger a backfill request to `dest` for the given `room_id`
 
         This will attempt to get more events from the remote. If the other side
@@ -683,9 +774,9 @@ class FederationHandler(BaseHandler):
         if dest == self.server_name:
             raise SynapseError(400, "Can't backfill from self.")
 
-        room_version = yield self.store.get_room_version(room_id)
+        room_version = await self.store.get_room_version(room_id)
 
-        events = yield self.federation_client.backfill(
+        events = await self.federation_client.backfill(
             dest, room_id, limit=limit, extremities=extremities
         )
 
@@ -700,7 +791,7 @@ class FederationHandler(BaseHandler):
         #     self._sanity_check_event(ev)
 
         # Don't bother processing events we already have.
-        seen_events = yield self.store.have_events_in_timeline(
+        seen_events = await self.store.have_events_in_timeline(
             set(e.event_id for e in events)
         )
 
@@ -723,7 +814,7 @@ class FederationHandler(BaseHandler):
         state_events = {}
         events_to_state = {}
         for e_id in edges:
-            state, auth = yield self.federation_client.get_state_for_room(
+            state, auth = await self._get_state_for_room(
                 destination=dest, room_id=room_id, event_id=e_id
             )
             auth_events.update({a.event_id: a for a in auth})
@@ -748,7 +839,7 @@ class FederationHandler(BaseHandler):
         # We repeatedly do this until we stop finding new auth events.
         while missing_auth - failed_to_fetch:
             logger.info("Missing auth for backfill: %r", missing_auth)
-            ret_events = yield self.store.get_events(missing_auth - failed_to_fetch)
+            ret_events = await self.store.get_events(missing_auth - failed_to_fetch)
             auth_events.update(ret_events)
 
             required_auth.update(
@@ -762,7 +853,7 @@ class FederationHandler(BaseHandler):
                     missing_auth - failed_to_fetch,
                 )
 
-                results = yield make_deferred_yieldable(
+                results = await make_deferred_yieldable(
                     defer.gatherResults(
                         [
                             run_in_background(
@@ -789,7 +880,7 @@ class FederationHandler(BaseHandler):
 
                 failed_to_fetch = missing_auth - set(auth_events)
 
-        seen_events = yield self.store.have_seen_events(
+        seen_events = await self.store.have_seen_events(
             set(auth_events.keys()) | set(state_events.keys())
         )
 
@@ -851,7 +942,7 @@ class FederationHandler(BaseHandler):
                 )
             )
 
-        yield self._handle_new_events(dest, ev_infos, backfilled=True)
+        await self._handle_new_events(dest, ev_infos, backfilled=True)
 
         # Step 2: Persist the rest of the events in the chunk one by one
         events.sort(key=lambda e: e.depth)
@@ -867,16 +958,15 @@ class FederationHandler(BaseHandler):
             # We store these one at a time since each event depends on the
             # previous to work out the state.
             # TODO: We can probably do something more clever here.
-            yield self._handle_new_event(dest, event, backfilled=True)
+            await self._handle_new_event(dest, event, backfilled=True)
 
         return events
 
-    @defer.inlineCallbacks
-    def maybe_backfill(self, room_id, current_depth):
+    async def maybe_backfill(self, room_id, current_depth):
         """Checks the database to see if we should backfill before paginating,
         and if so do.
         """
-        extremities = yield self.store.get_oldest_events_with_depth_in_room(room_id)
+        extremities = await self.store.get_oldest_events_with_depth_in_room(room_id)
 
         if not extremities:
             logger.debug("Not backfilling as no extremeties found.")
@@ -908,15 +998,17 @@ class FederationHandler(BaseHandler):
         #   state *before* the event, ignoring the special casing certain event
         #   types have.
 
-        forward_events = yield self.store.get_successor_events(list(extremities))
+        forward_events = await self.store.get_successor_events(list(extremities))
 
-        extremities_events = yield self.store.get_events(
-            forward_events, check_redacted=False, get_prev_content=False
+        extremities_events = await self.store.get_events(
+            forward_events,
+            redact_behaviour=EventRedactBehaviour.AS_IS,
+            get_prev_content=False,
         )
 
         # We set `check_history_visibility_only` as we might otherwise get false
         # positives from users having been erased.
-        filtered_extremities = yield filter_events_for_server(
+        filtered_extremities = await filter_events_for_server(
             self.storage,
             self.server_name,
             list(extremities_events.values()),
@@ -946,7 +1038,7 @@ class FederationHandler(BaseHandler):
         # First we try hosts that are already in the room
         # TODO: HEURISTIC ALERT.
 
-        curr_state = yield self.state_handler.get_current_state(room_id)
+        curr_state = await self.state_handler.get_current_state(room_id)
 
         def get_domains_from_state(state):
             """Get joined domains from state
@@ -985,12 +1077,11 @@ class FederationHandler(BaseHandler):
             domain for domain, depth in curr_domains if domain != self.server_name
         ]
 
-        @defer.inlineCallbacks
-        def try_backfill(domains):
+        async def try_backfill(domains):
             # TODO: Should we try multiple of these at a time?
             for dom in domains:
                 try:
-                    yield self.backfill(
+                    await self.backfill(
                         dom, room_id, limit=100, extremities=extremities
                     )
                     # If this succeeded then we probably already have the
@@ -1021,7 +1112,7 @@ class FederationHandler(BaseHandler):
 
             return False
 
-        success = yield try_backfill(likely_domains)
+        success = await try_backfill(likely_domains)
         if success:
             return True
 
@@ -1035,7 +1126,7 @@ class FederationHandler(BaseHandler):
 
         logger.debug("calling resolve_state_groups in _maybe_backfill")
         resolve = preserve_fn(self.state_handler.resolve_state_groups_for_events)
-        states = yield make_deferred_yieldable(
+        states = await make_deferred_yieldable(
             defer.gatherResults(
                 [resolve(room_id, [e]) for e in event_ids], consumeErrors=True
             )
@@ -1045,7 +1136,7 @@ class FederationHandler(BaseHandler):
         # event_ids.
         states = dict(zip(event_ids, [s.state for s in states]))
 
-        state_map = yield self.store.get_events(
+        state_map = await self.store.get_events(
             [e_id for ids in itervalues(states) for e_id in itervalues(ids)],
             get_prev_content=False,
         )
@@ -1061,7 +1152,7 @@ class FederationHandler(BaseHandler):
         for e_id, _ in sorted_extremeties_tuple:
             likely_domains = get_domains_from_state(states[e_id])
 
-            success = yield try_backfill(
+            success = await try_backfill(
                 [dom for dom, _ in likely_domains if dom not in tried_domains]
             )
             if success:
@@ -1210,7 +1301,7 @@ class FederationHandler(BaseHandler):
             # Check whether this room is the result of an upgrade of a room we already know
             # about. If so, migrate over user information
             predecessor = yield self.store.get_room_predecessor(room_id)
-            if not predecessor:
+            if not predecessor or not isinstance(predecessor.get("room_id"), str):
                 return
             old_room_id = predecessor["room_id"]
             logger.debug(
@@ -1238,8 +1329,7 @@ class FederationHandler(BaseHandler):
 
         return True
 
-    @defer.inlineCallbacks
-    def _handle_queued_pdus(self, room_queue):
+    async def _handle_queued_pdus(self, room_queue):
         """Process PDUs which got queued up while we were busy send_joining.
 
         Args:
@@ -1255,7 +1345,7 @@ class FederationHandler(BaseHandler):
                     p.room_id,
                 )
                 with nested_logging_context(p.event_id):
-                    yield self.on_receive_pdu(origin, p, sent_to_us_directly=True)
+                    await self.on_receive_pdu(origin, p, sent_to_us_directly=True)
             except Exception as e:
                 logger.warning(
                     "Error handling queued PDU %s from %s: %s", p.event_id, origin, e
@@ -1453,7 +1543,7 @@ class FederationHandler(BaseHandler):
     @defer.inlineCallbacks
     def do_remotely_reject_invite(self, target_hosts, room_id, user_id, content):
         origin, event, event_format_version = yield self._make_and_verify_event(
-            target_hosts, room_id, user_id, "leave", content=content,
+            target_hosts, room_id, user_id, "leave", content=content
         )
         # Mark as outlier as we don't have any state for this event; we're not
         # even in the room.
@@ -2814,7 +2904,7 @@ class FederationHandler(BaseHandler):
                 room_id=room_id, user_id=user.to_string(), change="joined"
             )
         else:
-            return user_joined_room(self.distributor, user, room_id)
+            return defer.succeed(user_joined_room(self.distributor, user, room_id))
 
     @defer.inlineCallbacks
     def get_room_complexity(self, remote_room_hosts, room_id):
diff --git a/synapse/handlers/message.py b/synapse/handlers/message.py
index 54fa216d83..bf9add7fe2 100644
--- a/synapse/handlers/message.py
+++ b/synapse/handlers/message.py
@@ -46,6 +46,7 @@ from synapse.events.validator import EventValidator
 from synapse.logging.context import run_in_background
 from synapse.metrics.background_process_metrics import run_as_background_process
 from synapse.replication.http.send_event import ReplicationSendEventRestServlet
+from synapse.storage.data_stores.main.events_worker import EventRedactBehaviour
 from synapse.storage.state import StateFilter
 from synapse.types import RoomAlias, UserID, create_requester
 from synapse.util.async_helpers import Linearizer
@@ -875,7 +876,7 @@ class EventCreationHandler(object):
             if event.type == EventTypes.Redaction:
                 original_event = yield self.store.get_event(
                     event.redacts,
-                    check_redacted=False,
+                    redact_behaviour=EventRedactBehaviour.AS_IS,
                     get_prev_content=False,
                     allow_rejected=False,
                     allow_none=True,
@@ -952,7 +953,7 @@ class EventCreationHandler(object):
         if event.type == EventTypes.Redaction:
             original_event = yield self.store.get_event(
                 event.redacts,
-                check_redacted=False,
+                redact_behaviour=EventRedactBehaviour.AS_IS,
                 get_prev_content=False,
                 allow_rejected=False,
                 allow_none=True,
diff --git a/synapse/handlers/pagination.py b/synapse/handlers/pagination.py
index 8514ddc600..00a6afc963 100644
--- a/synapse/handlers/pagination.py
+++ b/synapse/handlers/pagination.py
@@ -280,8 +280,7 @@ class PaginationHandler(object):
 
             await self.storage.purge_events.purge_room(room_id)
 
-    @defer.inlineCallbacks
-    def get_messages(
+    async def get_messages(
         self,
         requester,
         room_id=None,
@@ -307,7 +306,7 @@ class PaginationHandler(object):
             room_token = pagin_config.from_token.room_key
         else:
             pagin_config.from_token = (
-                yield self.hs.get_event_sources().get_current_token_for_pagination()
+                await self.hs.get_event_sources().get_current_token_for_pagination()
             )
             room_token = pagin_config.from_token.room_key
 
@@ -319,11 +318,11 @@ class PaginationHandler(object):
 
         source_config = pagin_config.get_source_config("room")
 
-        with (yield self.pagination_lock.read(room_id)):
+        with (await self.pagination_lock.read(room_id)):
             (
                 membership,
                 member_event_id,
-            ) = yield self.auth.check_in_room_or_world_readable(room_id, user_id)
+            ) = await self.auth.check_in_room_or_world_readable(room_id, user_id)
 
             if source_config.direction == "b":
                 # if we're going backwards, we might need to backfill. This
@@ -331,7 +330,7 @@ class PaginationHandler(object):
                 if room_token.topological:
                     max_topo = room_token.topological
                 else:
-                    max_topo = yield self.store.get_max_topological_token(
+                    max_topo = await self.store.get_max_topological_token(
                         room_id, room_token.stream
                     )
 
@@ -339,18 +338,18 @@ class PaginationHandler(object):
                     # If they have left the room then clamp the token to be before
                     # they left the room, to save the effort of loading from the
                     # database.
-                    leave_token = yield self.store.get_topological_token_for_event(
+                    leave_token = await self.store.get_topological_token_for_event(
                         member_event_id
                     )
                     leave_token = RoomStreamToken.parse(leave_token)
                     if leave_token.topological < max_topo:
                         source_config.from_key = str(leave_token)
 
-                yield self.hs.get_handlers().federation_handler.maybe_backfill(
+                await self.hs.get_handlers().federation_handler.maybe_backfill(
                     room_id, max_topo
                 )
 
-            events, next_key = yield self.store.paginate_room_events(
+            events, next_key = await self.store.paginate_room_events(
                 room_id=room_id,
                 from_key=source_config.from_key,
                 to_key=source_config.to_key,
@@ -365,7 +364,7 @@ class PaginationHandler(object):
             if event_filter:
                 events = event_filter.filter(events)
 
-            events = yield filter_events_for_client(
+            events = await filter_events_for_client(
                 self.storage, user_id, events, is_peeking=(member_event_id is None)
             )
 
@@ -385,19 +384,19 @@ class PaginationHandler(object):
                 (EventTypes.Member, event.sender) for event in events
             )
 
-            state_ids = yield self.state_store.get_state_ids_for_event(
+            state_ids = await self.state_store.get_state_ids_for_event(
                 events[0].event_id, state_filter=state_filter
             )
 
             if state_ids:
-                state = yield self.store.get_events(list(state_ids.values()))
+                state = await self.store.get_events(list(state_ids.values()))
                 state = state.values()
 
         time_now = self.clock.time_msec()
 
         chunk = {
             "chunk": (
-                yield self._event_serializer.serialize_events(
+                await self._event_serializer.serialize_events(
                     events, time_now, as_client_event=as_client_event
                 )
             ),
@@ -406,7 +405,7 @@ class PaginationHandler(object):
         }
 
         if state:
-            chunk["state"] = yield self._event_serializer.serialize_events(
+            chunk["state"] = await self._event_serializer.serialize_events(
                 state, time_now, as_client_event=as_client_event
             )
 
diff --git a/synapse/handlers/saml_handler.py b/synapse/handlers/saml_handler.py
index cc9e6b9bd0..0082f85c26 100644
--- a/synapse/handlers/saml_handler.py
+++ b/synapse/handlers/saml_handler.py
@@ -13,20 +13,36 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 import logging
+import re
+from typing import Tuple
 
 import attr
 import saml2
+import saml2.response
 from saml2.client import Saml2Client
 
 from synapse.api.errors import SynapseError
+from synapse.config import ConfigError
 from synapse.http.servlet import parse_string
 from synapse.rest.client.v1.login import SSOAuthHandler
-from synapse.types import UserID, map_username_to_mxid_localpart
+from synapse.types import (
+    UserID,
+    map_username_to_mxid_localpart,
+    mxid_localpart_allowed_characters,
+)
 from synapse.util.async_helpers import Linearizer
 
 logger = logging.getLogger(__name__)
 
 
+@attr.s
+class Saml2SessionData:
+    """Data we track about SAML2 sessions"""
+
+    # time the session was created, in milliseconds
+    creation_time = attr.ib()
+
+
 class SamlHandler:
     def __init__(self, hs):
         self._saml_client = Saml2Client(hs.config.saml2_sp_config)
@@ -37,11 +53,14 @@ class SamlHandler:
         self._datastore = hs.get_datastore()
         self._hostname = hs.hostname
         self._saml2_session_lifetime = hs.config.saml2_session_lifetime
-        self._mxid_source_attribute = hs.config.saml2_mxid_source_attribute
         self._grandfathered_mxid_source_attribute = (
             hs.config.saml2_grandfathered_mxid_source_attribute
         )
-        self._mxid_mapper = hs.config.saml2_mxid_mapper
+
+        # plugin to do custom mapping from saml response to mxid
+        self._user_mapping_provider = hs.config.saml2_user_mapping_provider_class(
+            hs.config.saml2_user_mapping_provider_config
+        )
 
         # identifier for the external_ids table
         self._auth_provider_id = "saml"
@@ -118,22 +137,10 @@ class SamlHandler:
             remote_user_id = saml2_auth.ava["uid"][0]
         except KeyError:
             logger.warning("SAML2 response lacks a 'uid' attestation")
-            raise SynapseError(400, "uid not in SAML2 response")
-
-        try:
-            mxid_source = saml2_auth.ava[self._mxid_source_attribute][0]
-        except KeyError:
-            logger.warning(
-                "SAML2 response lacks a '%s' attestation", self._mxid_source_attribute
-            )
-            raise SynapseError(
-                400, "%s not in SAML2 response" % (self._mxid_source_attribute,)
-            )
+            raise SynapseError(400, "'uid' not in SAML2 response")
 
         self._outstanding_requests_dict.pop(saml2_auth.in_response_to, None)
 
-        displayName = saml2_auth.ava.get("displayName", [None])[0]
-
         with (await self._mapping_lock.queue(self._auth_provider_id)):
             # first of all, check if we already have a mapping for this user
             logger.info(
@@ -173,22 +180,46 @@ class SamlHandler:
                     )
                     return registered_user_id
 
-            # figure out a new mxid for this user
-            base_mxid_localpart = self._mxid_mapper(mxid_source)
+            # Map saml response to user attributes using the configured mapping provider
+            for i in range(1000):
+                attribute_dict = self._user_mapping_provider.saml_response_to_user_attributes(
+                    saml2_auth, i
+                )
+
+                logger.debug(
+                    "Retrieved SAML attributes from user mapping provider: %s "
+                    "(attempt %d)",
+                    attribute_dict,
+                    i,
+                )
+
+                localpart = attribute_dict.get("mxid_localpart")
+                if not localpart:
+                    logger.error(
+                        "SAML mapping provider plugin did not return a "
+                        "mxid_localpart object"
+                    )
+                    raise SynapseError(500, "Error parsing SAML2 response")
 
-            suffix = 0
-            while True:
-                localpart = base_mxid_localpart + (str(suffix) if suffix else "")
+                displayname = attribute_dict.get("displayname")
+
+                # Check if this mxid already exists
                 if not await self._datastore.get_users_by_id_case_insensitive(
                     UserID(localpart, self._hostname).to_string()
                 ):
+                    # This mxid is free
                     break
-                suffix += 1
-            logger.info("Allocating mxid for new user with localpart %s", localpart)
+            else:
+                # Unable to generate a username in 1000 iterations
+                # Break and return error to the user
+                raise SynapseError(
+                    500, "Unable to generate a Matrix ID from the SAML response"
+                )
 
             registered_user_id = await self._registration_handler.register_user(
-                localpart=localpart, default_display_name=displayName
+                localpart=localpart, default_display_name=displayname
             )
+
             await self._datastore.record_user_external_id(
                 self._auth_provider_id, remote_user_id, registered_user_id
             )
@@ -205,9 +236,120 @@ class SamlHandler:
             del self._outstanding_requests_dict[reqid]
 
 
+DOT_REPLACE_PATTERN = re.compile(
+    ("[^%s]" % (re.escape("".join(mxid_localpart_allowed_characters)),))
+)
+
+
+def dot_replace_for_mxid(username: str) -> str:
+    username = username.lower()
+    username = DOT_REPLACE_PATTERN.sub(".", username)
+
+    # regular mxids aren't allowed to start with an underscore either
+    username = re.sub("^_", "", username)
+    return username
+
+
+MXID_MAPPER_MAP = {
+    "hexencode": map_username_to_mxid_localpart,
+    "dotreplace": dot_replace_for_mxid,
+}
+
+
 @attr.s
-class Saml2SessionData:
-    """Data we track about SAML2 sessions"""
+class SamlConfig(object):
+    mxid_source_attribute = attr.ib()
+    mxid_mapper = attr.ib()
 
-    # time the session was created, in milliseconds
-    creation_time = attr.ib()
+
+class DefaultSamlMappingProvider(object):
+    __version__ = "0.0.1"
+
+    def __init__(self, parsed_config: SamlConfig):
+        """The default SAML user mapping provider
+
+        Args:
+            parsed_config: Module configuration
+        """
+        self._mxid_source_attribute = parsed_config.mxid_source_attribute
+        self._mxid_mapper = parsed_config.mxid_mapper
+
+    def saml_response_to_user_attributes(
+        self, saml_response: saml2.response.AuthnResponse, failures: int = 0,
+    ) -> dict:
+        """Maps some text from a SAML response to attributes of a new user
+
+        Args:
+            saml_response: A SAML auth response object
+
+            failures: How many times a call to this function with this
+                saml_response has resulted in a failure
+
+        Returns:
+            dict: A dict containing new user attributes. Possible keys:
+                * mxid_localpart (str): Required. The localpart of the user's mxid
+                * displayname (str): The displayname of the user
+        """
+        try:
+            mxid_source = saml_response.ava[self._mxid_source_attribute][0]
+        except KeyError:
+            logger.warning(
+                "SAML2 response lacks a '%s' attestation", self._mxid_source_attribute,
+            )
+            raise SynapseError(
+                400, "%s not in SAML2 response" % (self._mxid_source_attribute,)
+            )
+
+        # Use the configured mapper for this mxid_source
+        base_mxid_localpart = self._mxid_mapper(mxid_source)
+
+        # Append suffix integer if last call to this function failed to produce
+        # a usable mxid
+        localpart = base_mxid_localpart + (str(failures) if failures else "")
+
+        # Retrieve the display name from the saml response
+        # If displayname is None, the mxid_localpart will be used instead
+        displayname = saml_response.ava.get("displayName", [None])[0]
+
+        return {
+            "mxid_localpart": localpart,
+            "displayname": displayname,
+        }
+
+    @staticmethod
+    def parse_config(config: dict) -> SamlConfig:
+        """Parse the dict provided by the homeserver's config
+        Args:
+            config: A dictionary containing configuration options for this provider
+        Returns:
+            SamlConfig: A custom config object for this module
+        """
+        # Parse config options and use defaults where necessary
+        mxid_source_attribute = config.get("mxid_source_attribute", "uid")
+        mapping_type = config.get("mxid_mapping", "hexencode")
+
+        # Retrieve the associating mapping function
+        try:
+            mxid_mapper = MXID_MAPPER_MAP[mapping_type]
+        except KeyError:
+            raise ConfigError(
+                "saml2_config.user_mapping_provider.config: '%s' is not a valid "
+                "mxid_mapping value" % (mapping_type,)
+            )
+
+        return SamlConfig(mxid_source_attribute, mxid_mapper)
+
+    @staticmethod
+    def get_saml_attributes(config: SamlConfig) -> Tuple[set, set]:
+        """Returns the required attributes of a SAML
+
+        Args:
+            config: A SamlConfig object containing configuration params for this provider
+
+        Returns:
+            tuple[set,set]: The first set equates to the saml auth response
+                attributes that are required for the module to function, whereas the
+                second set consists of those attributes which can be used if
+                available, but are not necessary
+        """
+        return {"uid", config.mxid_source_attribute}, {"displayName"}
diff --git a/synapse/handlers/search.py b/synapse/handlers/search.py
index 56ed262a1f..ef750d1497 100644
--- a/synapse/handlers/search.py
+++ b/synapse/handlers/search.py
@@ -21,7 +21,7 @@ from unpaddedbase64 import decode_base64, encode_base64
 from twisted.internet import defer
 
 from synapse.api.constants import EventTypes, Membership
-from synapse.api.errors import SynapseError
+from synapse.api.errors import NotFoundError, SynapseError
 from synapse.api.filtering import Filter
 from synapse.storage.state import StateFilter
 from synapse.visibility import filter_events_for_client
@@ -37,6 +37,7 @@ class SearchHandler(BaseHandler):
         self._event_serializer = hs.get_event_client_serializer()
         self.storage = hs.get_storage()
         self.state_store = self.storage.state
+        self.auth = hs.get_auth()
 
     @defer.inlineCallbacks
     def get_old_rooms_from_upgraded_room(self, room_id):
@@ -53,23 +54,38 @@ class SearchHandler(BaseHandler):
             room_id (str): id of the room to search through.
 
         Returns:
-            Deferred[iterable[unicode]]: predecessor room ids
+            Deferred[iterable[str]]: predecessor room ids
         """
 
         historical_room_ids = []
 
-        while True:
-            predecessor = yield self.store.get_room_predecessor(room_id)
+        # The initial room must have been known for us to get this far
+        predecessor = yield self.store.get_room_predecessor(room_id)
 
-            # If no predecessor, assume we've hit a dead end
+        while True:
             if not predecessor:
+                # We have reached the end of the chain of predecessors
+                break
+
+            if not isinstance(predecessor.get("room_id"), str):
+                # This predecessor object is malformed. Exit here
+                break
+
+            predecessor_room_id = predecessor["room_id"]
+
+            # Don't add it to the list until we have checked that we are in the room
+            try:
+                next_predecessor_room = yield self.store.get_room_predecessor(
+                    predecessor_room_id
+                )
+            except NotFoundError:
+                # The predecessor is not a known room, so we are done here
                 break
 
-            # Add predecessor's room ID
-            historical_room_ids.append(predecessor["room_id"])
+            historical_room_ids.append(predecessor_room_id)
 
-            # Scan through the old room for further predecessors
-            room_id = predecessor["room_id"]
+            # And repeat
+            predecessor = next_predecessor_room
 
         return historical_room_ids
 
diff --git a/synapse/logging/context.py b/synapse/logging/context.py
index 2c1fb9ddac..6747f29e6a 100644
--- a/synapse/logging/context.py
+++ b/synapse/logging/context.py
@@ -23,6 +23,7 @@ them.
 See doc/log_contexts.rst for details on how this works.
 """
 
+import inspect
 import logging
 import threading
 import types
@@ -612,7 +613,8 @@ def run_in_background(f, *args, **kwargs):
 
 
 def make_deferred_yieldable(deferred):
-    """Given a deferred, make it follow the Synapse logcontext rules:
+    """Given a deferred (or coroutine), make it follow the Synapse logcontext
+    rules:
 
     If the deferred has completed (or is not actually a Deferred), essentially
     does nothing (just returns another completed deferred with the
@@ -624,6 +626,13 @@ def make_deferred_yieldable(deferred):
 
     (This is more-or-less the opposite operation to run_in_background.)
     """
+    if inspect.isawaitable(deferred):
+        # If we're given a coroutine we convert it to a deferred so that we
+        # run it and find out if it immediately finishes, it it does then we
+        # don't need to fiddle with log contexts at all and can return
+        # immediately.
+        deferred = defer.ensureDeferred(deferred)
+
     if not isinstance(deferred, defer.Deferred):
         return deferred
 
diff --git a/synapse/replication/slave/storage/_base.py b/synapse/replication/slave/storage/_base.py
index 6ece1d6745..b91a528245 100644
--- a/synapse/replication/slave/storage/_base.py
+++ b/synapse/replication/slave/storage/_base.py
@@ -20,6 +20,7 @@ import six
 
 from synapse.storage._base import SQLBaseStore
 from synapse.storage.data_stores.main.cache import CURRENT_STATE_CACHE_NAME
+from synapse.storage.database import Database
 from synapse.storage.engines import PostgresEngine
 
 from ._slaved_id_tracker import SlavedIdTracker
@@ -35,8 +36,8 @@ def __func__(inp):
 
 
 class BaseSlavedStore(SQLBaseStore):
-    def __init__(self, db_conn, hs):
-        super(BaseSlavedStore, self).__init__(db_conn, hs)
+    def __init__(self, database: Database, db_conn, hs):
+        super(BaseSlavedStore, self).__init__(database, db_conn, hs)
         if isinstance(self.database_engine, PostgresEngine):
             self._cache_id_gen = SlavedIdTracker(
                 db_conn, "cache_invalidation_stream", "stream_id"
diff --git a/synapse/replication/slave/storage/account_data.py b/synapse/replication/slave/storage/account_data.py
index bc2f6a12ae..ebe94909cb 100644
--- a/synapse/replication/slave/storage/account_data.py
+++ b/synapse/replication/slave/storage/account_data.py
@@ -18,15 +18,16 @@ from synapse.replication.slave.storage._base import BaseSlavedStore
 from synapse.replication.slave.storage._slaved_id_tracker import SlavedIdTracker
 from synapse.storage.data_stores.main.account_data import AccountDataWorkerStore
 from synapse.storage.data_stores.main.tags import TagsWorkerStore
+from synapse.storage.database import Database
 
 
 class SlavedAccountDataStore(TagsWorkerStore, AccountDataWorkerStore, BaseSlavedStore):
-    def __init__(self, db_conn, hs):
+    def __init__(self, database: Database, db_conn, hs):
         self._account_data_id_gen = SlavedIdTracker(
             db_conn, "account_data_max_stream_id", "stream_id"
         )
 
-        super(SlavedAccountDataStore, self).__init__(db_conn, hs)
+        super(SlavedAccountDataStore, self).__init__(database, db_conn, hs)
 
     def get_max_account_data_stream_id(self):
         return self._account_data_id_gen.get_current_token()
diff --git a/synapse/replication/slave/storage/client_ips.py b/synapse/replication/slave/storage/client_ips.py
index b4f58cea19..fbf996e33a 100644
--- a/synapse/replication/slave/storage/client_ips.py
+++ b/synapse/replication/slave/storage/client_ips.py
@@ -14,6 +14,7 @@
 # limitations under the License.
 
 from synapse.storage.data_stores.main.client_ips import LAST_SEEN_GRANULARITY
+from synapse.storage.database import Database
 from synapse.util.caches import CACHE_SIZE_FACTOR
 from synapse.util.caches.descriptors import Cache
 
@@ -21,8 +22,8 @@ from ._base import BaseSlavedStore
 
 
 class SlavedClientIpStore(BaseSlavedStore):
-    def __init__(self, db_conn, hs):
-        super(SlavedClientIpStore, self).__init__(db_conn, hs)
+    def __init__(self, database: Database, db_conn, hs):
+        super(SlavedClientIpStore, self).__init__(database, db_conn, hs)
 
         self.client_ip_last_seen = Cache(
             name="client_ip_last_seen", keylen=4, max_entries=50000 * CACHE_SIZE_FACTOR
diff --git a/synapse/replication/slave/storage/deviceinbox.py b/synapse/replication/slave/storage/deviceinbox.py
index 9fb6c5c6ff..0c237c6e0f 100644
--- a/synapse/replication/slave/storage/deviceinbox.py
+++ b/synapse/replication/slave/storage/deviceinbox.py
@@ -16,13 +16,14 @@
 from synapse.replication.slave.storage._base import BaseSlavedStore
 from synapse.replication.slave.storage._slaved_id_tracker import SlavedIdTracker
 from synapse.storage.data_stores.main.deviceinbox import DeviceInboxWorkerStore
+from synapse.storage.database import Database
 from synapse.util.caches.expiringcache import ExpiringCache
 from synapse.util.caches.stream_change_cache import StreamChangeCache
 
 
 class SlavedDeviceInboxStore(DeviceInboxWorkerStore, BaseSlavedStore):
-    def __init__(self, db_conn, hs):
-        super(SlavedDeviceInboxStore, self).__init__(db_conn, hs)
+    def __init__(self, database: Database, db_conn, hs):
+        super(SlavedDeviceInboxStore, self).__init__(database, db_conn, hs)
         self._device_inbox_id_gen = SlavedIdTracker(
             db_conn, "device_max_stream_id", "stream_id"
         )
diff --git a/synapse/replication/slave/storage/devices.py b/synapse/replication/slave/storage/devices.py
index de50748c30..dc625e0d7a 100644
--- a/synapse/replication/slave/storage/devices.py
+++ b/synapse/replication/slave/storage/devices.py
@@ -18,12 +18,13 @@ from synapse.replication.slave.storage._slaved_id_tracker import SlavedIdTracker
 from synapse.replication.tcp.streams._base import DeviceListsStream, UserSignatureStream
 from synapse.storage.data_stores.main.devices import DeviceWorkerStore
 from synapse.storage.data_stores.main.end_to_end_keys import EndToEndKeyWorkerStore
+from synapse.storage.database import Database
 from synapse.util.caches.stream_change_cache import StreamChangeCache
 
 
 class SlavedDeviceStore(EndToEndKeyWorkerStore, DeviceWorkerStore, BaseSlavedStore):
-    def __init__(self, db_conn, hs):
-        super(SlavedDeviceStore, self).__init__(db_conn, hs)
+    def __init__(self, database: Database, db_conn, hs):
+        super(SlavedDeviceStore, self).__init__(database, db_conn, hs)
 
         self.hs = hs
 
diff --git a/synapse/replication/slave/storage/events.py b/synapse/replication/slave/storage/events.py
index d0a0eaf75b..29f35b9915 100644
--- a/synapse/replication/slave/storage/events.py
+++ b/synapse/replication/slave/storage/events.py
@@ -31,6 +31,7 @@ from synapse.storage.data_stores.main.signatures import SignatureWorkerStore
 from synapse.storage.data_stores.main.state import StateGroupWorkerStore
 from synapse.storage.data_stores.main.stream import StreamWorkerStore
 from synapse.storage.data_stores.main.user_erasure_store import UserErasureWorkerStore
+from synapse.storage.database import Database
 
 from ._base import BaseSlavedStore
 from ._slaved_id_tracker import SlavedIdTracker
@@ -59,13 +60,13 @@ class SlavedEventStore(
     RelationsWorkerStore,
     BaseSlavedStore,
 ):
-    def __init__(self, db_conn, hs):
+    def __init__(self, database: Database, db_conn, hs):
         self._stream_id_gen = SlavedIdTracker(db_conn, "events", "stream_ordering")
         self._backfill_id_gen = SlavedIdTracker(
             db_conn, "events", "stream_ordering", step=-1
         )
 
-        super(SlavedEventStore, self).__init__(db_conn, hs)
+        super(SlavedEventStore, self).__init__(database, db_conn, hs)
 
     # Cached functions can't be accessed through a class instance so we need
     # to reach inside the __dict__ to extract them.
diff --git a/synapse/replication/slave/storage/filtering.py b/synapse/replication/slave/storage/filtering.py
index 5c84ebd125..bcb0688954 100644
--- a/synapse/replication/slave/storage/filtering.py
+++ b/synapse/replication/slave/storage/filtering.py
@@ -14,13 +14,14 @@
 # limitations under the License.
 
 from synapse.storage.data_stores.main.filtering import FilteringStore
+from synapse.storage.database import Database
 
 from ._base import BaseSlavedStore
 
 
 class SlavedFilteringStore(BaseSlavedStore):
-    def __init__(self, db_conn, hs):
-        super(SlavedFilteringStore, self).__init__(db_conn, hs)
+    def __init__(self, database: Database, db_conn, hs):
+        super(SlavedFilteringStore, self).__init__(database, db_conn, hs)
 
     # Filters are immutable so this cache doesn't need to be expired
     get_user_filter = FilteringStore.__dict__["get_user_filter"]
diff --git a/synapse/replication/slave/storage/groups.py b/synapse/replication/slave/storage/groups.py
index 28a46edd28..69a4ae42f9 100644
--- a/synapse/replication/slave/storage/groups.py
+++ b/synapse/replication/slave/storage/groups.py
@@ -14,6 +14,7 @@
 # limitations under the License.
 
 from synapse.storage import DataStore
+from synapse.storage.database import Database
 from synapse.util.caches.stream_change_cache import StreamChangeCache
 
 from ._base import BaseSlavedStore, __func__
@@ -21,8 +22,8 @@ from ._slaved_id_tracker import SlavedIdTracker
 
 
 class SlavedGroupServerStore(BaseSlavedStore):
-    def __init__(self, db_conn, hs):
-        super(SlavedGroupServerStore, self).__init__(db_conn, hs)
+    def __init__(self, database: Database, db_conn, hs):
+        super(SlavedGroupServerStore, self).__init__(database, db_conn, hs)
 
         self.hs = hs
 
diff --git a/synapse/replication/slave/storage/presence.py b/synapse/replication/slave/storage/presence.py
index 747ced0c84..f552e7c972 100644
--- a/synapse/replication/slave/storage/presence.py
+++ b/synapse/replication/slave/storage/presence.py
@@ -15,6 +15,7 @@
 
 from synapse.storage import DataStore
 from synapse.storage.data_stores.main.presence import PresenceStore
+from synapse.storage.database import Database
 from synapse.util.caches.stream_change_cache import StreamChangeCache
 
 from ._base import BaseSlavedStore, __func__
@@ -22,8 +23,8 @@ from ._slaved_id_tracker import SlavedIdTracker
 
 
 class SlavedPresenceStore(BaseSlavedStore):
-    def __init__(self, db_conn, hs):
-        super(SlavedPresenceStore, self).__init__(db_conn, hs)
+    def __init__(self, database: Database, db_conn, hs):
+        super(SlavedPresenceStore, self).__init__(database, db_conn, hs)
         self._presence_id_gen = SlavedIdTracker(db_conn, "presence_stream", "stream_id")
 
         self._presence_on_startup = self._get_active_presence(db_conn)
diff --git a/synapse/replication/slave/storage/push_rule.py b/synapse/replication/slave/storage/push_rule.py
index 3655f05e54..eebd5a1fb6 100644
--- a/synapse/replication/slave/storage/push_rule.py
+++ b/synapse/replication/slave/storage/push_rule.py
@@ -15,17 +15,18 @@
 # limitations under the License.
 
 from synapse.storage.data_stores.main.push_rule import PushRulesWorkerStore
+from synapse.storage.database import Database
 
 from ._slaved_id_tracker import SlavedIdTracker
 from .events import SlavedEventStore
 
 
 class SlavedPushRuleStore(SlavedEventStore, PushRulesWorkerStore):
-    def __init__(self, db_conn, hs):
+    def __init__(self, database: Database, db_conn, hs):
         self._push_rules_stream_id_gen = SlavedIdTracker(
             db_conn, "push_rules_stream", "stream_id"
         )
-        super(SlavedPushRuleStore, self).__init__(db_conn, hs)
+        super(SlavedPushRuleStore, self).__init__(database, db_conn, hs)
 
     def get_push_rules_stream_token(self):
         return (
diff --git a/synapse/replication/slave/storage/pushers.py b/synapse/replication/slave/storage/pushers.py
index b4331d0799..f22c2d44a3 100644
--- a/synapse/replication/slave/storage/pushers.py
+++ b/synapse/replication/slave/storage/pushers.py
@@ -15,14 +15,15 @@
 # limitations under the License.
 
 from synapse.storage.data_stores.main.pusher import PusherWorkerStore
+from synapse.storage.database import Database
 
 from ._base import BaseSlavedStore
 from ._slaved_id_tracker import SlavedIdTracker
 
 
 class SlavedPusherStore(PusherWorkerStore, BaseSlavedStore):
-    def __init__(self, db_conn, hs):
-        super(SlavedPusherStore, self).__init__(db_conn, hs)
+    def __init__(self, database: Database, db_conn, hs):
+        super(SlavedPusherStore, self).__init__(database, db_conn, hs)
         self._pushers_id_gen = SlavedIdTracker(
             db_conn, "pushers", "id", extra_tables=[("deleted_pushers", "stream_id")]
         )
diff --git a/synapse/replication/slave/storage/receipts.py b/synapse/replication/slave/storage/receipts.py
index 43d823c601..d40dc6e1f5 100644
--- a/synapse/replication/slave/storage/receipts.py
+++ b/synapse/replication/slave/storage/receipts.py
@@ -15,6 +15,7 @@
 # limitations under the License.
 
 from synapse.storage.data_stores.main.receipts import ReceiptsWorkerStore
+from synapse.storage.database import Database
 
 from ._base import BaseSlavedStore
 from ._slaved_id_tracker import SlavedIdTracker
@@ -29,14 +30,14 @@ from ._slaved_id_tracker import SlavedIdTracker
 
 
 class SlavedReceiptsStore(ReceiptsWorkerStore, BaseSlavedStore):
-    def __init__(self, db_conn, hs):
+    def __init__(self, database: Database, db_conn, hs):
         # We instantiate this first as the ReceiptsWorkerStore constructor
         # needs to be able to call get_max_receipt_stream_id
         self._receipts_id_gen = SlavedIdTracker(
             db_conn, "receipts_linearized", "stream_id"
         )
 
-        super(SlavedReceiptsStore, self).__init__(db_conn, hs)
+        super(SlavedReceiptsStore, self).__init__(database, db_conn, hs)
 
     def get_max_receipt_stream_id(self):
         return self._receipts_id_gen.get_current_token()
diff --git a/synapse/replication/slave/storage/room.py b/synapse/replication/slave/storage/room.py
index d9ad386b28..3a20f45316 100644
--- a/synapse/replication/slave/storage/room.py
+++ b/synapse/replication/slave/storage/room.py
@@ -14,14 +14,15 @@
 # limitations under the License.
 
 from synapse.storage.data_stores.main.room import RoomWorkerStore
+from synapse.storage.database import Database
 
 from ._base import BaseSlavedStore
 from ._slaved_id_tracker import SlavedIdTracker
 
 
 class RoomStore(RoomWorkerStore, BaseSlavedStore):
-    def __init__(self, db_conn, hs):
-        super(RoomStore, self).__init__(db_conn, hs)
+    def __init__(self, database: Database, db_conn, hs):
+        super(RoomStore, self).__init__(database, db_conn, hs)
         self._public_room_id_gen = SlavedIdTracker(
             db_conn, "public_room_list_stream", "stream_id"
         )
diff --git a/synapse/rest/client/v1/profile.py b/synapse/rest/client/v1/profile.py
index 1eac8a44c5..e7fe50ed72 100644
--- a/synapse/rest/client/v1/profile.py
+++ b/synapse/rest/client/v1/profile.py
@@ -15,6 +15,7 @@
 
 """ This module contains REST servlets to do with profile: /profile/<paths> """
 
+from synapse.api.errors import Codes, SynapseError
 from synapse.http.servlet import RestServlet, parse_json_object_from_request
 from synapse.rest.client.v2_alpha._base import client_patterns
 from synapse.types import UserID
@@ -103,11 +104,15 @@ class ProfileAvatarURLRestServlet(RestServlet):
 
         content = parse_json_object_from_request(request)
         try:
-            new_name = content["avatar_url"]
-        except Exception:
-            return 400, "Unable to parse name"
-
-        await self.profile_handler.set_avatar_url(user, requester, new_name, is_admin)
+            new_avatar_url = content["avatar_url"]
+        except KeyError:
+            raise SynapseError(
+                400, "Missing key 'avatar_url'", errcode=Codes.MISSING_PARAM
+            )
+
+        await self.profile_handler.set_avatar_url(
+            user, requester, new_avatar_url, is_admin
+        )
 
         return 200, {}
 
diff --git a/synapse/server.py b/synapse/server.py
index be9af7f986..2db3dab221 100644
--- a/synapse/server.py
+++ b/synapse/server.py
@@ -238,8 +238,7 @@ class HomeServer(object):
     def setup(self):
         logger.info("Setting up.")
         with self.get_db_conn() as conn:
-            datastore = self.DATASTORE_CLASS(conn, self)
-            self.datastores = DataStores(datastore, conn, self)
+            self.datastores = DataStores(self.DATASTORE_CLASS, conn, self)
             conn.commit()
         self.start_time = int(self.get_clock().time())
         logger.info("Finished setting up.")
diff --git a/synapse/state/__init__.py b/synapse/state/__init__.py
index 139beef8ed..3e6d62eef1 100644
--- a/synapse/state/__init__.py
+++ b/synapse/state/__init__.py
@@ -32,6 +32,7 @@ from synapse.events import EventBase
 from synapse.events.snapshot import EventContext
 from synapse.logging.utils import log_function
 from synapse.state import v1, v2
+from synapse.storage.data_stores.main.events_worker import EventRedactBehaviour
 from synapse.util.async_helpers import Linearizer
 from synapse.util.caches import get_cache_factor_for
 from synapse.util.caches.expiringcache import ExpiringCache
@@ -645,7 +646,7 @@ class StateResolutionStore(object):
 
         return self.store.get_events(
             event_ids,
-            check_redacted=False,
+            redact_behaviour=EventRedactBehaviour.AS_IS,
             get_prev_content=False,
             allow_rejected=allow_rejected,
         )
diff --git a/synapse/storage/__init__.py b/synapse/storage/__init__.py
index 8fb18203dc..ec89f645d4 100644
--- a/synapse/storage/__init__.py
+++ b/synapse/storage/__init__.py
@@ -49,15 +49,3 @@ class Storage(object):
         self.persistence = EventsPersistenceStorage(hs, stores)
         self.purge_events = PurgeEventsStorage(hs, stores)
         self.state = StateGroupStorage(hs, stores)
-
-
-def are_all_users_on_domain(txn, database_engine, domain):
-    sql = database_engine.convert_param_style(
-        "SELECT COUNT(*) FROM users WHERE name NOT LIKE ?"
-    )
-    pat = "%:" + domain
-    txn.execute(sql, (pat,))
-    num_not_matching = txn.fetchall()[0][0]
-    if num_not_matching == 0:
-        return True
-    return False
diff --git a/synapse/storage/_base.py b/synapse/storage/_base.py
index b7e27d4e97..b7637b5dc0 100644
--- a/synapse/storage/_base.py
+++ b/synapse/storage/_base.py
@@ -37,11 +37,11 @@ class SQLBaseStore(object):
     per data store (and not one per physical database).
     """
 
-    def __init__(self, db_conn, hs):
+    def __init__(self, database: Database, db_conn, hs):
         self.hs = hs
         self._clock = hs.get_clock()
         self.database_engine = hs.database_engine
-        self.db = Database(hs)  # In future this will be passed in
+        self.db = database
         self.rand = random.SystemRandom()
 
     def _invalidate_state_caches(self, room_id, members_changed):
diff --git a/synapse/storage/background_updates.py b/synapse/storage/background_updates.py
index a9a13a2658..4f97fd5ab6 100644
--- a/synapse/storage/background_updates.py
+++ b/synapse/storage/background_updates.py
@@ -379,7 +379,7 @@ class BackgroundUpdater(object):
             logger.debug("[SQL] %s", sql)
             c.execute(sql)
 
-        if isinstance(self.db.database_engine, engines.PostgresEngine):
+        if isinstance(self.db.engine, engines.PostgresEngine):
             runner = create_index_psql
         elif psql_only:
             runner = None
diff --git a/synapse/storage/data_stores/__init__.py b/synapse/storage/data_stores/__init__.py
index cb184a98cc..cafedd5c0d 100644
--- a/synapse/storage/data_stores/__init__.py
+++ b/synapse/storage/data_stores/__init__.py
@@ -13,6 +13,9 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
+from synapse.storage.database import Database
+from synapse.storage.prepare_database import prepare_database
+
 
 class DataStores(object):
     """The various data stores.
@@ -20,7 +23,14 @@ class DataStores(object):
     These are low level interfaces to physical databases.
     """
 
-    def __init__(self, main_store, db_conn, hs):
-        # Note we pass in the main store here as workers use a different main
+    def __init__(self, main_store_class, db_conn, hs):
+        # Note we pass in the main store class here as workers use a different main
         # store.
-        self.main = main_store
+        database = Database(hs)
+
+        # Check that db is correctly configured.
+        database.engine.check_database(db_conn.cursor())
+
+        prepare_database(db_conn, database.engine, config=hs.config)
+
+        self.main = main_store_class(database, db_conn, hs)
diff --git a/synapse/storage/data_stores/main/__init__.py b/synapse/storage/data_stores/main/__init__.py
index 6adb8adb04..c577c0df5f 100644
--- a/synapse/storage/data_stores/main/__init__.py
+++ b/synapse/storage/data_stores/main/__init__.py
@@ -20,6 +20,7 @@ import logging
 import time
 
 from synapse.api.constants import PresenceState
+from synapse.storage.database import Database
 from synapse.storage.engines import PostgresEngine
 from synapse.storage.util.id_generators import (
     ChainedIdGenerator,
@@ -111,10 +112,20 @@ class DataStore(
     RelationsStore,
     CacheInvalidationStore,
 ):
-    def __init__(self, db_conn, hs):
+    def __init__(self, database: Database, db_conn, hs):
         self.hs = hs
         self._clock = hs.get_clock()
-        self.database_engine = hs.database_engine
+        self.database_engine = database.engine
+
+        all_users_native = are_all_users_on_domain(
+            db_conn.cursor(), database.engine, hs.hostname
+        )
+        if not all_users_native:
+            raise Exception(
+                "Found users in database not native to %s!\n"
+                "You cannot changed a synapse server_name after it's been configured"
+                % (hs.hostname,)
+            )
 
         self._stream_id_gen = StreamIdGenerator(
             db_conn,
@@ -169,7 +180,7 @@ class DataStore(
         else:
             self._cache_id_gen = None
 
-        super(DataStore, self).__init__(db_conn, hs)
+        super(DataStore, self).__init__(database, db_conn, hs)
 
         self._presence_on_startup = self._get_active_presence(db_conn)
 
@@ -554,3 +565,15 @@ class DataStore(
             retcols=["name", "password_hash", "is_guest", "admin", "user_type"],
             desc="search_users",
         )
+
+
+def are_all_users_on_domain(txn, database_engine, domain):
+    sql = database_engine.convert_param_style(
+        "SELECT COUNT(*) FROM users WHERE name NOT LIKE ?"
+    )
+    pat = "%:" + domain
+    txn.execute(sql, (pat,))
+    num_not_matching = txn.fetchall()[0][0]
+    if num_not_matching == 0:
+        return True
+    return False
diff --git a/synapse/storage/data_stores/main/account_data.py b/synapse/storage/data_stores/main/account_data.py
index 1a3dd7be8d..46b494b334 100644
--- a/synapse/storage/data_stores/main/account_data.py
+++ b/synapse/storage/data_stores/main/account_data.py
@@ -22,6 +22,7 @@ from canonicaljson import json
 from twisted.internet import defer
 
 from synapse.storage._base import SQLBaseStore
+from synapse.storage.database import Database
 from synapse.storage.util.id_generators import StreamIdGenerator
 from synapse.util.caches.descriptors import cached, cachedInlineCallbacks
 from synapse.util.caches.stream_change_cache import StreamChangeCache
@@ -38,13 +39,13 @@ class AccountDataWorkerStore(SQLBaseStore):
     # the abstract methods being implemented.
     __metaclass__ = abc.ABCMeta
 
-    def __init__(self, db_conn, hs):
+    def __init__(self, database: Database, db_conn, hs):
         account_max = self.get_max_account_data_stream_id()
         self._account_data_stream_cache = StreamChangeCache(
             "AccountDataAndTagsChangeCache", account_max
         )
 
-        super(AccountDataWorkerStore, self).__init__(db_conn, hs)
+        super(AccountDataWorkerStore, self).__init__(database, db_conn, hs)
 
     @abc.abstractmethod
     def get_max_account_data_stream_id(self):
@@ -270,12 +271,12 @@ class AccountDataWorkerStore(SQLBaseStore):
 
 
 class AccountDataStore(AccountDataWorkerStore):
-    def __init__(self, db_conn, hs):
+    def __init__(self, database: Database, db_conn, hs):
         self._account_data_id_gen = StreamIdGenerator(
             db_conn, "account_data_max_stream_id", "stream_id"
         )
 
-        super(AccountDataStore, self).__init__(db_conn, hs)
+        super(AccountDataStore, self).__init__(database, db_conn, hs)
 
     def get_max_account_data_stream_id(self):
         """Get the current max stream id for the private user data stream
diff --git a/synapse/storage/data_stores/main/appservice.py b/synapse/storage/data_stores/main/appservice.py
index 6b2e12719c..b2f39649fd 100644
--- a/synapse/storage/data_stores/main/appservice.py
+++ b/synapse/storage/data_stores/main/appservice.py
@@ -24,6 +24,7 @@ from synapse.appservice import AppServiceTransaction
 from synapse.config.appservice import load_appservices
 from synapse.storage._base import SQLBaseStore
 from synapse.storage.data_stores.main.events_worker import EventsWorkerStore
+from synapse.storage.database import Database
 
 logger = logging.getLogger(__name__)
 
@@ -48,13 +49,13 @@ def _make_exclusive_regex(services_cache):
 
 
 class ApplicationServiceWorkerStore(SQLBaseStore):
-    def __init__(self, db_conn, hs):
+    def __init__(self, database: Database, db_conn, hs):
         self.services_cache = load_appservices(
             hs.hostname, hs.config.app_service_config_files
         )
         self.exclusive_user_regex = _make_exclusive_regex(self.services_cache)
 
-        super(ApplicationServiceWorkerStore, self).__init__(db_conn, hs)
+        super(ApplicationServiceWorkerStore, self).__init__(database, db_conn, hs)
 
     def get_app_services(self):
         return self.services_cache
diff --git a/synapse/storage/data_stores/main/client_ips.py b/synapse/storage/data_stores/main/client_ips.py
index 7b470a58f1..add3037b69 100644
--- a/synapse/storage/data_stores/main/client_ips.py
+++ b/synapse/storage/data_stores/main/client_ips.py
@@ -21,6 +21,7 @@ from twisted.internet import defer
 
 from synapse.metrics.background_process_metrics import wrap_as_background_process
 from synapse.storage._base import SQLBaseStore
+from synapse.storage.database import Database
 from synapse.util.caches import CACHE_SIZE_FACTOR
 from synapse.util.caches.descriptors import Cache
 
@@ -33,8 +34,8 @@ LAST_SEEN_GRANULARITY = 120 * 1000
 
 
 class ClientIpBackgroundUpdateStore(SQLBaseStore):
-    def __init__(self, db_conn, hs):
-        super(ClientIpBackgroundUpdateStore, self).__init__(db_conn, hs)
+    def __init__(self, database: Database, db_conn, hs):
+        super(ClientIpBackgroundUpdateStore, self).__init__(database, db_conn, hs)
 
         self.db.updates.register_background_index_update(
             "user_ips_device_index",
@@ -363,13 +364,13 @@ class ClientIpBackgroundUpdateStore(SQLBaseStore):
 
 
 class ClientIpStore(ClientIpBackgroundUpdateStore):
-    def __init__(self, db_conn, hs):
+    def __init__(self, database: Database, db_conn, hs):
 
         self.client_ip_last_seen = Cache(
             name="client_ip_last_seen", keylen=4, max_entries=50000 * CACHE_SIZE_FACTOR
         )
 
-        super(ClientIpStore, self).__init__(db_conn, hs)
+        super(ClientIpStore, self).__init__(database, db_conn, hs)
 
         self.user_ips_max_age = hs.config.user_ips_max_age
 
@@ -450,16 +451,18 @@ class ClientIpStore(ClientIpBackgroundUpdateStore):
                 # Technically an access token might not be associated with
                 # a device so we need to check.
                 if device_id:
-                    self.db.simple_upsert_txn(
+                    # this is always an update rather than an upsert: the row should
+                    # already exist, and if it doesn't, that may be because it has been
+                    # deleted, and we don't want to re-create it.
+                    self.db.simple_update_txn(
                         txn,
                         table="devices",
                         keyvalues={"user_id": user_id, "device_id": device_id},
-                        values={
+                        updatevalues={
                             "user_agent": user_agent,
                             "last_seen": last_seen,
                             "ip": ip,
                         },
-                        lock=False,
                     )
             except Exception as e:
                 # Failed to upsert, log and continue
diff --git a/synapse/storage/data_stores/main/deviceinbox.py b/synapse/storage/data_stores/main/deviceinbox.py
index 3c9f09301a..85cfa16850 100644
--- a/synapse/storage/data_stores/main/deviceinbox.py
+++ b/synapse/storage/data_stores/main/deviceinbox.py
@@ -21,6 +21,7 @@ from twisted.internet import defer
 
 from synapse.logging.opentracing import log_kv, set_tag, trace
 from synapse.storage._base import SQLBaseStore, make_in_list_sql_clause
+from synapse.storage.database import Database
 from synapse.util.caches.expiringcache import ExpiringCache
 
 logger = logging.getLogger(__name__)
@@ -210,8 +211,8 @@ class DeviceInboxWorkerStore(SQLBaseStore):
 class DeviceInboxBackgroundUpdateStore(SQLBaseStore):
     DEVICE_INBOX_STREAM_ID = "device_inbox_stream_drop"
 
-    def __init__(self, db_conn, hs):
-        super(DeviceInboxBackgroundUpdateStore, self).__init__(db_conn, hs)
+    def __init__(self, database: Database, db_conn, hs):
+        super(DeviceInboxBackgroundUpdateStore, self).__init__(database, db_conn, hs)
 
         self.db.updates.register_background_index_update(
             "device_inbox_stream_index",
@@ -241,8 +242,8 @@ class DeviceInboxBackgroundUpdateStore(SQLBaseStore):
 class DeviceInboxStore(DeviceInboxWorkerStore, DeviceInboxBackgroundUpdateStore):
     DEVICE_INBOX_STREAM_ID = "device_inbox_stream_drop"
 
-    def __init__(self, db_conn, hs):
-        super(DeviceInboxStore, self).__init__(db_conn, hs)
+    def __init__(self, database: Database, db_conn, hs):
+        super(DeviceInboxStore, self).__init__(database, db_conn, hs)
 
         # Map of (user_id, device_id) to the last stream_id that has been
         # deleted up to. This is so that we can no op deletions.
diff --git a/synapse/storage/data_stores/main/devices.py b/synapse/storage/data_stores/main/devices.py
index 91ddaf137e..9a828231c4 100644
--- a/synapse/storage/data_stores/main/devices.py
+++ b/synapse/storage/data_stores/main/devices.py
@@ -31,6 +31,7 @@ from synapse.logging.opentracing import (
 )
 from synapse.metrics.background_process_metrics import run_as_background_process
 from synapse.storage._base import SQLBaseStore, db_to_json, make_in_list_sql_clause
+from synapse.storage.database import Database
 from synapse.types import get_verify_key_from_cross_signing_key
 from synapse.util import batch_iter
 from synapse.util.caches.descriptors import (
@@ -642,8 +643,8 @@ class DeviceWorkerStore(SQLBaseStore):
 
 
 class DeviceBackgroundUpdateStore(SQLBaseStore):
-    def __init__(self, db_conn, hs):
-        super(DeviceBackgroundUpdateStore, self).__init__(db_conn, hs)
+    def __init__(self, database: Database, db_conn, hs):
+        super(DeviceBackgroundUpdateStore, self).__init__(database, db_conn, hs)
 
         self.db.updates.register_background_index_update(
             "device_lists_stream_idx",
@@ -692,8 +693,8 @@ class DeviceBackgroundUpdateStore(SQLBaseStore):
 
 
 class DeviceStore(DeviceWorkerStore, DeviceBackgroundUpdateStore):
-    def __init__(self, db_conn, hs):
-        super(DeviceStore, self).__init__(db_conn, hs)
+    def __init__(self, database: Database, db_conn, hs):
+        super(DeviceStore, self).__init__(database, db_conn, hs)
 
         # Map of (user_id, device_id) -> bool. If there is an entry that implies
         # the device exists.
diff --git a/synapse/storage/data_stores/main/event_federation.py b/synapse/storage/data_stores/main/event_federation.py
index 31d2e8eb28..1f517e8fad 100644
--- a/synapse/storage/data_stores/main/event_federation.py
+++ b/synapse/storage/data_stores/main/event_federation.py
@@ -28,6 +28,7 @@ from synapse.metrics.background_process_metrics import run_as_background_process
 from synapse.storage._base import SQLBaseStore, make_in_list_sql_clause
 from synapse.storage.data_stores.main.events_worker import EventsWorkerStore
 from synapse.storage.data_stores.main.signatures import SignatureWorkerStore
+from synapse.storage.database import Database
 from synapse.util.caches.descriptors import cached
 
 logger = logging.getLogger(__name__)
@@ -491,8 +492,8 @@ class EventFederationStore(EventFederationWorkerStore):
 
     EVENT_AUTH_STATE_ONLY = "event_auth_state_only"
 
-    def __init__(self, db_conn, hs):
-        super(EventFederationStore, self).__init__(db_conn, hs)
+    def __init__(self, database: Database, db_conn, hs):
+        super(EventFederationStore, self).__init__(database, db_conn, hs)
 
         self.db.updates.register_background_update_handler(
             self.EVENT_AUTH_STATE_ONLY, self._background_delete_non_state_event_auth
diff --git a/synapse/storage/data_stores/main/event_push_actions.py b/synapse/storage/data_stores/main/event_push_actions.py
index eec054cd48..9988a6d3fc 100644
--- a/synapse/storage/data_stores/main/event_push_actions.py
+++ b/synapse/storage/data_stores/main/event_push_actions.py
@@ -24,6 +24,7 @@ from twisted.internet import defer
 
 from synapse.metrics.background_process_metrics import run_as_background_process
 from synapse.storage._base import LoggingTransaction, SQLBaseStore
+from synapse.storage.database import Database
 from synapse.util.caches.descriptors import cachedInlineCallbacks
 
 logger = logging.getLogger(__name__)
@@ -68,8 +69,8 @@ def _deserialize_action(actions, is_highlight):
 
 
 class EventPushActionsWorkerStore(SQLBaseStore):
-    def __init__(self, db_conn, hs):
-        super(EventPushActionsWorkerStore, self).__init__(db_conn, hs)
+    def __init__(self, database: Database, db_conn, hs):
+        super(EventPushActionsWorkerStore, self).__init__(database, db_conn, hs)
 
         # These get correctly set by _find_stream_orderings_for_times_txn
         self.stream_ordering_month_ago = None
@@ -611,8 +612,8 @@ class EventPushActionsWorkerStore(SQLBaseStore):
 class EventPushActionsStore(EventPushActionsWorkerStore):
     EPA_HIGHLIGHT_INDEX = "epa_highlight_index"
 
-    def __init__(self, db_conn, hs):
-        super(EventPushActionsStore, self).__init__(db_conn, hs)
+    def __init__(self, database: Database, db_conn, hs):
+        super(EventPushActionsStore, self).__init__(database, db_conn, hs)
 
         self.db.updates.register_background_index_update(
             self.EPA_HIGHLIGHT_INDEX,
diff --git a/synapse/storage/data_stores/main/events.py b/synapse/storage/data_stores/main/events.py
index d644c82784..998bba1aad 100644
--- a/synapse/storage/data_stores/main/events.py
+++ b/synapse/storage/data_stores/main/events.py
@@ -41,6 +41,7 @@ from synapse.storage._base import make_in_list_sql_clause
 from synapse.storage.data_stores.main.event_federation import EventFederationStore
 from synapse.storage.data_stores.main.events_worker import EventsWorkerStore
 from synapse.storage.data_stores.main.state import StateGroupWorkerStore
+from synapse.storage.database import Database
 from synapse.types import RoomStreamToken, get_domain_from_id
 from synapse.util import batch_iter
 from synapse.util.caches.descriptors import cached, cachedInlineCallbacks
@@ -95,8 +96,8 @@ def _retry_on_integrity_error(func):
 class EventsStore(
     StateGroupWorkerStore, EventFederationStore, EventsWorkerStore,
 ):
-    def __init__(self, db_conn, hs):
-        super(EventsStore, self).__init__(db_conn, hs)
+    def __init__(self, database: Database, db_conn, hs):
+        super(EventsStore, self).__init__(database, db_conn, hs)
 
         # Collect metrics on the number of forward extremities that exist.
         # Counter of number of extremities to count
@@ -1038,20 +1039,25 @@ class EventsStore(
             },
         )
 
-    @defer.inlineCallbacks
-    def _censor_redactions(self):
+    async def _censor_redactions(self):
         """Censors all redactions older than the configured period that haven't
         been censored yet.
 
         By censor we mean update the event_json table with the redacted event.
-
-        Returns:
-            Deferred
         """
 
         if self.hs.config.redaction_retention_period is None:
             return
 
+        if not (
+            await self.db.updates.has_completed_background_update(
+                "redactions_have_censored_ts_idx"
+            )
+        ):
+            # We don't want to run this until the appropriate index has been
+            # created.
+            return
+
         before_ts = self._clock.time_msec() - self.hs.config.redaction_retention_period
 
         # We fetch all redactions that:
@@ -1073,15 +1079,15 @@ class EventsStore(
             LIMIT ?
         """
 
-        rows = yield self.db.execute(
+        rows = await self.db.execute(
             "_censor_redactions_fetch", None, sql, before_ts, 100
         )
 
         updates = []
 
         for redaction_id, event_id in rows:
-            redaction_event = yield self.get_event(redaction_id, allow_none=True)
-            original_event = yield self.get_event(
+            redaction_event = await self.get_event(redaction_id, allow_none=True)
+            original_event = await self.get_event(
                 event_id, allow_rejected=True, allow_none=True
             )
 
@@ -1114,7 +1120,7 @@ class EventsStore(
                     updatevalues={"have_censored": True},
                 )
 
-        yield self.db.runInteraction("_update_censor_txn", _update_censor_txn)
+        await self.db.runInteraction("_update_censor_txn", _update_censor_txn)
 
     def _censor_event_txn(self, txn, event_id, pruned_json):
         """Censor an event by replacing its JSON in the event_json table with the
diff --git a/synapse/storage/data_stores/main/events_bg_updates.py b/synapse/storage/data_stores/main/events_bg_updates.py
index cb1fc30c31..5177b71016 100644
--- a/synapse/storage/data_stores/main/events_bg_updates.py
+++ b/synapse/storage/data_stores/main/events_bg_updates.py
@@ -23,6 +23,7 @@ from twisted.internet import defer
 
 from synapse.api.constants import EventContentFields
 from synapse.storage._base import SQLBaseStore, make_in_list_sql_clause
+from synapse.storage.database import Database
 
 logger = logging.getLogger(__name__)
 
@@ -33,8 +34,8 @@ class EventsBackgroundUpdatesStore(SQLBaseStore):
     EVENT_FIELDS_SENDER_URL_UPDATE_NAME = "event_fields_sender_url"
     DELETE_SOFT_FAILED_EXTREMITIES = "delete_soft_failed_extremities"
 
-    def __init__(self, db_conn, hs):
-        super(EventsBackgroundUpdatesStore, self).__init__(db_conn, hs)
+    def __init__(self, database: Database, db_conn, hs):
+        super(EventsBackgroundUpdatesStore, self).__init__(database, db_conn, hs)
 
         self.db.updates.register_background_update_handler(
             self.EVENT_ORIGIN_SERVER_TS_NAME, self._background_reindex_origin_server_ts
@@ -89,6 +90,14 @@ class EventsBackgroundUpdatesStore(SQLBaseStore):
             "event_store_labels", self._event_store_labels
         )
 
+        self.db.updates.register_background_index_update(
+            "redactions_have_censored_ts_idx",
+            index_name="redactions_have_censored_ts",
+            table="redactions",
+            columns=["received_ts"],
+            where_clause="NOT have_censored",
+        )
+
     @defer.inlineCallbacks
     def _background_reindex_fields_sender(self, progress, batch_size):
         target_min_stream_id = progress["target_min_stream_id_inclusive"]
diff --git a/synapse/storage/data_stores/main/events_worker.py b/synapse/storage/data_stores/main/events_worker.py
index e041fc5eac..2c9142814c 100644
--- a/synapse/storage/data_stores/main/events_worker.py
+++ b/synapse/storage/data_stores/main/events_worker.py
@@ -19,8 +19,10 @@ import itertools
 import logging
 import threading
 from collections import namedtuple
+from typing import List, Optional
 
 from canonicaljson import json
+from constantly import NamedConstant, Names
 
 from twisted.internet import defer
 
@@ -33,6 +35,7 @@ from synapse.events.utils import prune_event
 from synapse.logging.context import LoggingContext, PreserveLoggingContext
 from synapse.metrics.background_process_metrics import run_as_background_process
 from synapse.storage._base import SQLBaseStore, make_in_list_sql_clause
+from synapse.storage.database import Database
 from synapse.types import get_domain_from_id
 from synapse.util import batch_iter
 from synapse.util.caches.descriptors import Cache
@@ -54,9 +57,19 @@ EVENT_QUEUE_TIMEOUT_S = 0.1  # Timeout when waiting for requests for events
 _EventCacheEntry = namedtuple("_EventCacheEntry", ("event", "redacted_event"))
 
 
+class EventRedactBehaviour(Names):
+    """
+    What to do when retrieving a redacted event from the database.
+    """
+
+    AS_IS = NamedConstant()
+    REDACT = NamedConstant()
+    BLOCK = NamedConstant()
+
+
 class EventsWorkerStore(SQLBaseStore):
-    def __init__(self, db_conn, hs):
-        super(EventsWorkerStore, self).__init__(db_conn, hs)
+    def __init__(self, database: Database, db_conn, hs):
+        super(EventsWorkerStore, self).__init__(database, db_conn, hs)
 
         self._get_event_cache = Cache(
             "*getEvent*", keylen=3, max_entries=hs.config.event_cache_size
@@ -124,25 +137,27 @@ class EventsWorkerStore(SQLBaseStore):
     @defer.inlineCallbacks
     def get_event(
         self,
-        event_id,
-        check_redacted=True,
-        get_prev_content=False,
-        allow_rejected=False,
-        allow_none=False,
-        check_room_id=None,
+        event_id: List[str],
+        redact_behaviour: EventRedactBehaviour = EventRedactBehaviour.REDACT,
+        get_prev_content: bool = False,
+        allow_rejected: bool = False,
+        allow_none: bool = False,
+        check_room_id: Optional[str] = None,
     ):
         """Get an event from the database by event_id.
 
         Args:
-            event_id (str): The event_id of the event to fetch
-            check_redacted (bool): If True, check if event has been redacted
-                and redact it.
-            get_prev_content (bool): If True and event is a state event,
+            event_id: The event_id of the event to fetch
+            redact_behaviour: Determine what to do with a redacted event. Possible values:
+                * AS_IS - Return the full event body with no redacted content
+                * REDACT - Return the event but with a redacted body
+                * DISALLOW - Do not return redacted events
+            get_prev_content: If True and event is a state event,
                 include the previous states content in the unsigned field.
-            allow_rejected (bool): If True return rejected events.
-            allow_none (bool): If True, return None if no event found, if
+            allow_rejected: If True return rejected events.
+            allow_none: If True, return None if no event found, if
                 False throw a NotFoundError
-            check_room_id (str|None): if not None, check the room of the found event.
+            check_room_id: if not None, check the room of the found event.
                 If there is a mismatch, behave as per allow_none.
 
         Returns:
@@ -153,7 +168,7 @@ class EventsWorkerStore(SQLBaseStore):
 
         events = yield self.get_events_as_list(
             [event_id],
-            check_redacted=check_redacted,
+            redact_behaviour=redact_behaviour,
             get_prev_content=get_prev_content,
             allow_rejected=allow_rejected,
         )
@@ -172,27 +187,30 @@ class EventsWorkerStore(SQLBaseStore):
     @defer.inlineCallbacks
     def get_events(
         self,
-        event_ids,
-        check_redacted=True,
-        get_prev_content=False,
-        allow_rejected=False,
+        event_ids: List[str],
+        redact_behaviour: EventRedactBehaviour = EventRedactBehaviour.REDACT,
+        get_prev_content: bool = False,
+        allow_rejected: bool = False,
     ):
         """Get events from the database
 
         Args:
-            event_ids (list): The event_ids of the events to fetch
-            check_redacted (bool): If True, check if event has been redacted
-                and redact it.
-            get_prev_content (bool): If True and event is a state event,
+            event_ids: The event_ids of the events to fetch
+            redact_behaviour: Determine what to do with a redacted event. Possible
+                values:
+                * AS_IS - Return the full event body with no redacted content
+                * REDACT - Return the event but with a redacted body
+                * DISALLOW - Do not return redacted events
+            get_prev_content: If True and event is a state event,
                 include the previous states content in the unsigned field.
-            allow_rejected (bool): If True return rejected events.
+            allow_rejected: If True return rejected events.
 
         Returns:
             Deferred : Dict from event_id to event.
         """
         events = yield self.get_events_as_list(
             event_ids,
-            check_redacted=check_redacted,
+            redact_behaviour=redact_behaviour,
             get_prev_content=get_prev_content,
             allow_rejected=allow_rejected,
         )
@@ -202,21 +220,23 @@ class EventsWorkerStore(SQLBaseStore):
     @defer.inlineCallbacks
     def get_events_as_list(
         self,
-        event_ids,
-        check_redacted=True,
-        get_prev_content=False,
-        allow_rejected=False,
+        event_ids: List[str],
+        redact_behaviour: EventRedactBehaviour = EventRedactBehaviour.REDACT,
+        get_prev_content: bool = False,
+        allow_rejected: bool = False,
     ):
         """Get events from the database and return in a list in the same order
         as given by `event_ids` arg.
 
         Args:
-            event_ids (list): The event_ids of the events to fetch
-            check_redacted (bool): If True, check if event has been redacted
-                and redact it.
-            get_prev_content (bool): If True and event is a state event,
+            event_ids: The event_ids of the events to fetch
+            redact_behaviour: Determine what to do with a redacted event. Possible values:
+                * AS_IS - Return the full event body with no redacted content
+                * REDACT - Return the event but with a redacted body
+                * DISALLOW - Do not return redacted events
+            get_prev_content: If True and event is a state event,
                 include the previous states content in the unsigned field.
-            allow_rejected (bool): If True return rejected events.
+            allow_rejected: If True, return rejected events.
 
         Returns:
             Deferred[list[EventBase]]: List of events fetched from the database. The
@@ -318,10 +338,14 @@ class EventsWorkerStore(SQLBaseStore):
                     # Update the cache to save doing the checks again.
                     entry.event.internal_metadata.recheck_redaction = False
 
-            if check_redacted and entry.redacted_event:
-                event = entry.redacted_event
-            else:
-                event = entry.event
+            event = entry.event
+
+            if entry.redacted_event:
+                if redact_behaviour == EventRedactBehaviour.BLOCK:
+                    # Skip this event
+                    continue
+                elif redact_behaviour == EventRedactBehaviour.REDACT:
+                    event = entry.redacted_event
 
             events.append(event)
 
diff --git a/synapse/storage/data_stores/main/media_repository.py b/synapse/storage/data_stores/main/media_repository.py
index 03c9c6f8ae..80ca36dedf 100644
--- a/synapse/storage/data_stores/main/media_repository.py
+++ b/synapse/storage/data_stores/main/media_repository.py
@@ -13,11 +13,14 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 from synapse.storage._base import SQLBaseStore
+from synapse.storage.database import Database
 
 
 class MediaRepositoryBackgroundUpdateStore(SQLBaseStore):
-    def __init__(self, db_conn, hs):
-        super(MediaRepositoryBackgroundUpdateStore, self).__init__(db_conn, hs)
+    def __init__(self, database: Database, db_conn, hs):
+        super(MediaRepositoryBackgroundUpdateStore, self).__init__(
+            database, db_conn, hs
+        )
 
         self.db.updates.register_background_index_update(
             update_name="local_media_repository_url_idx",
@@ -31,8 +34,8 @@ class MediaRepositoryBackgroundUpdateStore(SQLBaseStore):
 class MediaRepositoryStore(MediaRepositoryBackgroundUpdateStore):
     """Persistence for attachments and avatars"""
 
-    def __init__(self, db_conn, hs):
-        super(MediaRepositoryStore, self).__init__(db_conn, hs)
+    def __init__(self, database: Database, db_conn, hs):
+        super(MediaRepositoryStore, self).__init__(database, db_conn, hs)
 
     def get_local_media(self, media_id):
         """Get the metadata for a local piece of media
diff --git a/synapse/storage/data_stores/main/monthly_active_users.py b/synapse/storage/data_stores/main/monthly_active_users.py
index 34bf3a1880..27158534cb 100644
--- a/synapse/storage/data_stores/main/monthly_active_users.py
+++ b/synapse/storage/data_stores/main/monthly_active_users.py
@@ -17,6 +17,7 @@ import logging
 from twisted.internet import defer
 
 from synapse.storage._base import SQLBaseStore
+from synapse.storage.database import Database
 from synapse.util.caches.descriptors import cached
 
 logger = logging.getLogger(__name__)
@@ -27,13 +28,13 @@ LAST_SEEN_GRANULARITY = 60 * 60 * 1000
 
 
 class MonthlyActiveUsersStore(SQLBaseStore):
-    def __init__(self, dbconn, hs):
-        super(MonthlyActiveUsersStore, self).__init__(None, hs)
+    def __init__(self, database: Database, db_conn, hs):
+        super(MonthlyActiveUsersStore, self).__init__(database, db_conn, hs)
         self._clock = hs.get_clock()
         self.hs = hs
         # Do not add more reserved users than the total allowable number
         self.db.new_transaction(
-            dbconn,
+            db_conn,
             "initialise_mau_threepids",
             [],
             [],
diff --git a/synapse/storage/data_stores/main/push_rule.py b/synapse/storage/data_stores/main/push_rule.py
index de682cc63a..5ba13aa973 100644
--- a/synapse/storage/data_stores/main/push_rule.py
+++ b/synapse/storage/data_stores/main/push_rule.py
@@ -27,6 +27,7 @@ from synapse.storage.data_stores.main.appservice import ApplicationServiceWorker
 from synapse.storage.data_stores.main.pusher import PusherWorkerStore
 from synapse.storage.data_stores.main.receipts import ReceiptsWorkerStore
 from synapse.storage.data_stores.main.roommember import RoomMemberWorkerStore
+from synapse.storage.database import Database
 from synapse.storage.push_rule import InconsistentRuleException, RuleNotFoundException
 from synapse.util.caches.descriptors import cachedInlineCallbacks, cachedList
 from synapse.util.caches.stream_change_cache import StreamChangeCache
@@ -72,8 +73,8 @@ class PushRulesWorkerStore(
     # the abstract methods being implemented.
     __metaclass__ = abc.ABCMeta
 
-    def __init__(self, db_conn, hs):
-        super(PushRulesWorkerStore, self).__init__(db_conn, hs)
+    def __init__(self, database: Database, db_conn, hs):
+        super(PushRulesWorkerStore, self).__init__(database, db_conn, hs)
 
         push_rules_prefill, push_rules_id = self.db.get_cache_dict(
             db_conn,
diff --git a/synapse/storage/data_stores/main/receipts.py b/synapse/storage/data_stores/main/receipts.py
index ac2d45bd5c..96e54d145e 100644
--- a/synapse/storage/data_stores/main/receipts.py
+++ b/synapse/storage/data_stores/main/receipts.py
@@ -22,6 +22,7 @@ from canonicaljson import json
 from twisted.internet import defer
 
 from synapse.storage._base import SQLBaseStore, make_in_list_sql_clause
+from synapse.storage.database import Database
 from synapse.storage.util.id_generators import StreamIdGenerator
 from synapse.util.caches.descriptors import cached, cachedInlineCallbacks, cachedList
 from synapse.util.caches.stream_change_cache import StreamChangeCache
@@ -38,8 +39,8 @@ class ReceiptsWorkerStore(SQLBaseStore):
     # the abstract methods being implemented.
     __metaclass__ = abc.ABCMeta
 
-    def __init__(self, db_conn, hs):
-        super(ReceiptsWorkerStore, self).__init__(db_conn, hs)
+    def __init__(self, database: Database, db_conn, hs):
+        super(ReceiptsWorkerStore, self).__init__(database, db_conn, hs)
 
         self._receipts_stream_cache = StreamChangeCache(
             "ReceiptsRoomChangeCache", self.get_max_receipt_stream_id()
@@ -315,14 +316,14 @@ class ReceiptsWorkerStore(SQLBaseStore):
 
 
 class ReceiptsStore(ReceiptsWorkerStore):
-    def __init__(self, db_conn, hs):
+    def __init__(self, database: Database, db_conn, hs):
         # We instantiate this first as the ReceiptsWorkerStore constructor
         # needs to be able to call get_max_receipt_stream_id
         self._receipts_id_gen = StreamIdGenerator(
             db_conn, "receipts_linearized", "stream_id"
         )
 
-        super(ReceiptsStore, self).__init__(db_conn, hs)
+        super(ReceiptsStore, self).__init__(database, db_conn, hs)
 
     def get_max_receipt_stream_id(self):
         return self._receipts_id_gen.get_current_token()
diff --git a/synapse/storage/data_stores/main/registration.py b/synapse/storage/data_stores/main/registration.py
index 1ef143c6d8..5e8ecac0ea 100644
--- a/synapse/storage/data_stores/main/registration.py
+++ b/synapse/storage/data_stores/main/registration.py
@@ -27,6 +27,7 @@ from synapse.api.constants import UserTypes
 from synapse.api.errors import Codes, StoreError, SynapseError, ThreepidValidationError
 from synapse.metrics.background_process_metrics import run_as_background_process
 from synapse.storage._base import SQLBaseStore
+from synapse.storage.database import Database
 from synapse.types import UserID
 from synapse.util.caches.descriptors import cached, cachedInlineCallbacks
 
@@ -36,8 +37,8 @@ logger = logging.getLogger(__name__)
 
 
 class RegistrationWorkerStore(SQLBaseStore):
-    def __init__(self, db_conn, hs):
-        super(RegistrationWorkerStore, self).__init__(db_conn, hs)
+    def __init__(self, database: Database, db_conn, hs):
+        super(RegistrationWorkerStore, self).__init__(database, db_conn, hs)
 
         self.config = hs.config
         self.clock = hs.get_clock()
@@ -794,8 +795,8 @@ class RegistrationWorkerStore(SQLBaseStore):
 
 
 class RegistrationBackgroundUpdateStore(RegistrationWorkerStore):
-    def __init__(self, db_conn, hs):
-        super(RegistrationBackgroundUpdateStore, self).__init__(db_conn, hs)
+    def __init__(self, database: Database, db_conn, hs):
+        super(RegistrationBackgroundUpdateStore, self).__init__(database, db_conn, hs)
 
         self.clock = hs.get_clock()
         self.config = hs.config
@@ -920,8 +921,8 @@ class RegistrationBackgroundUpdateStore(RegistrationWorkerStore):
 
 
 class RegistrationStore(RegistrationBackgroundUpdateStore):
-    def __init__(self, db_conn, hs):
-        super(RegistrationStore, self).__init__(db_conn, hs)
+    def __init__(self, database: Database, db_conn, hs):
+        super(RegistrationStore, self).__init__(database, db_conn, hs)
 
         self._account_validity = hs.config.account_validity
 
diff --git a/synapse/storage/data_stores/main/room.py b/synapse/storage/data_stores/main/room.py
index da42dae243..aa476d0fbf 100644
--- a/synapse/storage/data_stores/main/room.py
+++ b/synapse/storage/data_stores/main/room.py
@@ -29,6 +29,7 @@ from synapse.api.constants import EventTypes
 from synapse.api.errors import StoreError
 from synapse.storage._base import SQLBaseStore
 from synapse.storage.data_stores.main.search import SearchStore
+from synapse.storage.database import Database
 from synapse.types import ThirdPartyInstanceID
 from synapse.util.caches.descriptors import cached, cachedInlineCallbacks
 
@@ -45,6 +46,11 @@ RatelimitOverride = collections.namedtuple(
 
 
 class RoomWorkerStore(SQLBaseStore):
+    def __init__(self, database: Database, db_conn, hs):
+        super(RoomWorkerStore, self).__init__(database, db_conn, hs)
+
+        self.config = hs.config
+
     def get_room(self, room_id):
         """Retrieve a room.
 
@@ -361,8 +367,8 @@ class RoomWorkerStore(SQLBaseStore):
 
 
 class RoomBackgroundUpdateStore(SQLBaseStore):
-    def __init__(self, db_conn, hs):
-        super(RoomBackgroundUpdateStore, self).__init__(db_conn, hs)
+    def __init__(self, database: Database, db_conn, hs):
+        super(RoomBackgroundUpdateStore, self).__init__(database, db_conn, hs)
 
         self.config = hs.config
 
@@ -440,8 +446,8 @@ class RoomBackgroundUpdateStore(SQLBaseStore):
 
 
 class RoomStore(RoomBackgroundUpdateStore, RoomWorkerStore, SearchStore):
-    def __init__(self, db_conn, hs):
-        super(RoomStore, self).__init__(db_conn, hs)
+    def __init__(self, database: Database, db_conn, hs):
+        super(RoomStore, self).__init__(database, db_conn, hs)
 
         self.config = hs.config
 
diff --git a/synapse/storage/data_stores/main/roommember.py b/synapse/storage/data_stores/main/roommember.py
index 929f6b0d39..92e3b9c512 100644
--- a/synapse/storage/data_stores/main/roommember.py
+++ b/synapse/storage/data_stores/main/roommember.py
@@ -32,6 +32,7 @@ from synapse.storage._base import (
     make_in_list_sql_clause,
 )
 from synapse.storage.data_stores.main.events_worker import EventsWorkerStore
+from synapse.storage.database import Database
 from synapse.storage.engines import Sqlite3Engine
 from synapse.storage.roommember import (
     GetRoomsForUserWithStreamOrdering,
@@ -54,8 +55,8 @@ _CURRENT_STATE_MEMBERSHIP_UPDATE_NAME = "current_state_events_membership"
 
 
 class RoomMemberWorkerStore(EventsWorkerStore):
-    def __init__(self, db_conn, hs):
-        super(RoomMemberWorkerStore, self).__init__(db_conn, hs)
+    def __init__(self, database: Database, db_conn, hs):
+        super(RoomMemberWorkerStore, self).__init__(database, db_conn, hs)
 
         # Is the current_state_events.membership up to date? Or is the
         # background update still running?
@@ -835,8 +836,8 @@ class RoomMemberWorkerStore(EventsWorkerStore):
 
 
 class RoomMemberBackgroundUpdateStore(SQLBaseStore):
-    def __init__(self, db_conn, hs):
-        super(RoomMemberBackgroundUpdateStore, self).__init__(db_conn, hs)
+    def __init__(self, database: Database, db_conn, hs):
+        super(RoomMemberBackgroundUpdateStore, self).__init__(database, db_conn, hs)
         self.db.updates.register_background_update_handler(
             _MEMBERSHIP_PROFILE_UPDATE_NAME, self._background_add_membership_profile
         )
@@ -991,8 +992,8 @@ class RoomMemberBackgroundUpdateStore(SQLBaseStore):
 
 
 class RoomMemberStore(RoomMemberWorkerStore, RoomMemberBackgroundUpdateStore):
-    def __init__(self, db_conn, hs):
-        super(RoomMemberStore, self).__init__(db_conn, hs)
+    def __init__(self, database: Database, db_conn, hs):
+        super(RoomMemberStore, self).__init__(database, db_conn, hs)
 
     def _store_room_members_txn(self, txn, events, backfilled):
         """Store a room member in the database.
diff --git a/synapse/storage/data_stores/main/schema/delta/56/redaction_censor.sql b/synapse/storage/data_stores/main/schema/delta/56/redaction_censor.sql
index fe51b02309..ea95db0ed7 100644
--- a/synapse/storage/data_stores/main/schema/delta/56/redaction_censor.sql
+++ b/synapse/storage/data_stores/main/schema/delta/56/redaction_censor.sql
@@ -14,4 +14,3 @@
  */
 
 ALTER TABLE redactions ADD COLUMN have_censored BOOL NOT NULL DEFAULT false;
-CREATE INDEX redactions_have_censored ON redactions(event_id) WHERE not have_censored;
diff --git a/synapse/storage/data_stores/main/schema/delta/56/redaction_censor2.sql b/synapse/storage/data_stores/main/schema/delta/56/redaction_censor2.sql
index 77a5eca499..49ce35d794 100644
--- a/synapse/storage/data_stores/main/schema/delta/56/redaction_censor2.sql
+++ b/synapse/storage/data_stores/main/schema/delta/56/redaction_censor2.sql
@@ -14,7 +14,9 @@
  */
 
 ALTER TABLE redactions ADD COLUMN received_ts BIGINT;
-CREATE INDEX redactions_have_censored_ts ON redactions(received_ts) WHERE not have_censored;
 
 INSERT INTO background_updates (update_name, progress_json) VALUES
   ('redactions_received_ts', '{}');
+
+INSERT INTO background_updates (update_name, progress_json) VALUES
+  ('redactions_have_censored_ts_idx', '{}');
diff --git a/synapse/storage/data_stores/main/schema/delta/56/redaction_censor4.sql b/synapse/storage/data_stores/main/schema/delta/56/redaction_censor4.sql
new file mode 100644
index 0000000000..b7550f6f4e
--- /dev/null
+++ b/synapse/storage/data_stores/main/schema/delta/56/redaction_censor4.sql
@@ -0,0 +1,16 @@
+/* Copyright 2019 The Matrix.org Foundation C.I.C.
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+DROP INDEX IF EXISTS redactions_have_censored;
diff --git a/synapse/storage/data_stores/main/schema/full_schemas/README.md b/synapse/storage/data_stores/main/schema/full_schemas/README.md
new file mode 100644
index 0000000000..bbd3f18604
--- /dev/null
+++ b/synapse/storage/data_stores/main/schema/full_schemas/README.md
@@ -0,0 +1,13 @@
+# Building full schema dumps
+
+These schemas need to be made from a database that has had all background updates run.
+
+To do so, use `scripts-dev/make_full_schema.sh`. This will produce
+`full.sql.postgres ` and `full.sql.sqlite` files.
+
+Ensure postgres is installed and your user has the ability to run bash commands
+such as `createdb`.
+
+```
+./scripts-dev/make_full_schema.sh -p postgres_username -o output_dir/
+```
diff --git a/synapse/storage/data_stores/main/schema/full_schemas/README.txt b/synapse/storage/data_stores/main/schema/full_schemas/README.txt
deleted file mode 100644
index d3f6401344..0000000000
--- a/synapse/storage/data_stores/main/schema/full_schemas/README.txt
+++ /dev/null
@@ -1,19 +0,0 @@
-Building full schema dumps
-==========================
-
-These schemas need to be made from a database that has had all background updates run.
-
-Postgres
---------
-
-$ pg_dump --format=plain --schema-only --no-tablespaces --no-acl --no-owner $DATABASE_NAME| sed -e '/^--/d' -e 's/public\.//g' -e '/^SET /d' -e '/^SELECT /d' > full.sql.postgres
-
-SQLite
-------
-
-$ sqlite3 $DATABASE_FILE ".schema" > full.sql.sqlite
-
-After
------
-
-Delete the CREATE statements for "sqlite_stat1", "schema_version", "applied_schema_deltas", and "applied_module_schemas".
\ No newline at end of file
diff --git a/synapse/storage/data_stores/main/search.py b/synapse/storage/data_stores/main/search.py
index ffa1817e64..dfb46ee0f8 100644
--- a/synapse/storage/data_stores/main/search.py
+++ b/synapse/storage/data_stores/main/search.py
@@ -25,6 +25,8 @@ from twisted.internet import defer
 
 from synapse.api.errors import SynapseError
 from synapse.storage._base import SQLBaseStore, make_in_list_sql_clause
+from synapse.storage.data_stores.main.events_worker import EventRedactBehaviour
+from synapse.storage.database import Database
 from synapse.storage.engines import PostgresEngine, Sqlite3Engine
 
 logger = logging.getLogger(__name__)
@@ -42,8 +44,8 @@ class SearchBackgroundUpdateStore(SQLBaseStore):
     EVENT_SEARCH_USE_GIST_POSTGRES_NAME = "event_search_postgres_gist"
     EVENT_SEARCH_USE_GIN_POSTGRES_NAME = "event_search_postgres_gin"
 
-    def __init__(self, db_conn, hs):
-        super(SearchBackgroundUpdateStore, self).__init__(db_conn, hs)
+    def __init__(self, database: Database, db_conn, hs):
+        super(SearchBackgroundUpdateStore, self).__init__(database, db_conn, hs)
 
         if not hs.config.enable_search:
             return
@@ -342,8 +344,8 @@ class SearchBackgroundUpdateStore(SQLBaseStore):
 
 
 class SearchStore(SearchBackgroundUpdateStore):
-    def __init__(self, db_conn, hs):
-        super(SearchStore, self).__init__(db_conn, hs)
+    def __init__(self, database: Database, db_conn, hs):
+        super(SearchStore, self).__init__(database, db_conn, hs)
 
     def store_event_search_txn(self, txn, event, key, value):
         """Add event to the search table
@@ -452,7 +454,12 @@ class SearchStore(SearchBackgroundUpdateStore):
 
         results = list(filter(lambda row: row["room_id"] in room_ids, results))
 
-        events = yield self.get_events_as_list([r["event_id"] for r in results])
+        # We set redact_behaviour to BLOCK here to prevent redacted events being returned in
+        # search results (which is a data leak)
+        events = yield self.get_events_as_list(
+            [r["event_id"] for r in results],
+            redact_behaviour=EventRedactBehaviour.BLOCK,
+        )
 
         event_map = {ev.event_id: ev for ev in events}
 
diff --git a/synapse/storage/data_stores/main/state.py b/synapse/storage/data_stores/main/state.py
index 7d5a9f8128..dcc6b43cdf 100644
--- a/synapse/storage/data_stores/main/state.py
+++ b/synapse/storage/data_stores/main/state.py
@@ -28,6 +28,7 @@ from synapse.events import EventBase
 from synapse.events.snapshot import EventContext
 from synapse.storage._base import SQLBaseStore
 from synapse.storage.data_stores.main.events_worker import EventsWorkerStore
+from synapse.storage.database import Database
 from synapse.storage.engines import PostgresEngine
 from synapse.storage.state import StateFilter
 from synapse.util.caches import get_cache_factor_for, intern_string
@@ -213,8 +214,8 @@ class StateGroupWorkerStore(
     STATE_GROUP_INDEX_UPDATE_NAME = "state_group_state_type_index"
     CURRENT_STATE_INDEX_UPDATE_NAME = "current_state_members_idx"
 
-    def __init__(self, db_conn, hs):
-        super(StateGroupWorkerStore, self).__init__(db_conn, hs)
+    def __init__(self, database: Database, db_conn, hs):
+        super(StateGroupWorkerStore, self).__init__(database, db_conn, hs)
 
         # Originally the state store used a single DictionaryCache to cache the
         # event IDs for the state types in a given state group to avoid hammering
@@ -277,7 +278,7 @@ class StateGroupWorkerStore(
 
     @defer.inlineCallbacks
     def get_room_predecessor(self, room_id):
-        """Get the predecessor room of an upgraded room if one exists.
+        """Get the predecessor of an upgraded room if it exists.
         Otherwise return None.
 
         Args:
@@ -290,14 +291,22 @@ class StateGroupWorkerStore(
                     * room_id (str): The room ID of the predecessor room
                     * event_id (str): The ID of the tombstone event in the predecessor room
 
+                None if a predecessor key is not found, or is not a dictionary.
+
         Raises:
-            NotFoundError if the room is unknown
+            NotFoundError if the given room is unknown
         """
         # Retrieve the room's create event
         create_event = yield self.get_create_event_for_room(room_id)
 
-        # Return predecessor if present
-        return create_event.content.get("predecessor", None)
+        # Retrieve the predecessor key of the create event
+        predecessor = create_event.content.get("predecessor", None)
+
+        # Ensure the key is a dictionary
+        if not isinstance(predecessor, dict):
+            return None
+
+        return predecessor
 
     @defer.inlineCallbacks
     def get_create_event_for_room(self, room_id):
@@ -317,7 +326,7 @@ class StateGroupWorkerStore(
 
         # If we can't find the create event, assume we've hit a dead end
         if not create_id:
-            raise NotFoundError("Unknown room %s" % (room_id))
+            raise NotFoundError("Unknown room %s" % (room_id,))
 
         # Retrieve the room's create event and return
         create_event = yield self.get_event(create_id)
@@ -1029,8 +1038,8 @@ class StateBackgroundUpdateStore(StateGroupBackgroundUpdateStore):
     CURRENT_STATE_INDEX_UPDATE_NAME = "current_state_members_idx"
     EVENT_STATE_GROUP_INDEX_UPDATE_NAME = "event_to_state_groups_sg_index"
 
-    def __init__(self, db_conn, hs):
-        super(StateBackgroundUpdateStore, self).__init__(db_conn, hs)
+    def __init__(self, database: Database, db_conn, hs):
+        super(StateBackgroundUpdateStore, self).__init__(database, db_conn, hs)
         self.db.updates.register_background_update_handler(
             self.STATE_GROUP_DEDUPLICATION_UPDATE_NAME,
             self._background_deduplicate_state,
@@ -1245,8 +1254,8 @@ class StateStore(StateGroupWorkerStore, StateBackgroundUpdateStore):
       * `state_groups_state`: Maps state group to state events.
     """
 
-    def __init__(self, db_conn, hs):
-        super(StateStore, self).__init__(db_conn, hs)
+    def __init__(self, database: Database, db_conn, hs):
+        super(StateStore, self).__init__(database, db_conn, hs)
 
     def _store_event_state_mappings_txn(
         self, txn, events_and_contexts: Iterable[Tuple[EventBase, EventContext]]
diff --git a/synapse/storage/data_stores/main/stats.py b/synapse/storage/data_stores/main/stats.py
index 40579bf965..7bc186e9a1 100644
--- a/synapse/storage/data_stores/main/stats.py
+++ b/synapse/storage/data_stores/main/stats.py
@@ -22,6 +22,7 @@ from twisted.internet.defer import DeferredLock
 
 from synapse.api.constants import EventTypes, Membership
 from synapse.storage.data_stores.main.state_deltas import StateDeltasStore
+from synapse.storage.database import Database
 from synapse.storage.engines import PostgresEngine
 from synapse.util.caches.descriptors import cached
 
@@ -58,8 +59,8 @@ TYPE_TO_ORIGIN_TABLE = {"room": ("rooms", "room_id"), "user": ("users", "name")}
 
 
 class StatsStore(StateDeltasStore):
-    def __init__(self, db_conn, hs):
-        super(StatsStore, self).__init__(db_conn, hs)
+    def __init__(self, database: Database, db_conn, hs):
+        super(StatsStore, self).__init__(database, db_conn, hs)
 
         self.server_name = hs.hostname
         self.clock = self.hs.get_clock()
diff --git a/synapse/storage/data_stores/main/stream.py b/synapse/storage/data_stores/main/stream.py
index 2ff8c57109..140da8dad6 100644
--- a/synapse/storage/data_stores/main/stream.py
+++ b/synapse/storage/data_stores/main/stream.py
@@ -47,6 +47,7 @@ from twisted.internet import defer
 from synapse.logging.context import make_deferred_yieldable, run_in_background
 from synapse.storage._base import SQLBaseStore
 from synapse.storage.data_stores.main.events_worker import EventsWorkerStore
+from synapse.storage.database import Database
 from synapse.storage.engines import PostgresEngine
 from synapse.types import RoomStreamToken
 from synapse.util.caches.stream_change_cache import StreamChangeCache
@@ -251,8 +252,8 @@ class StreamWorkerStore(EventsWorkerStore, SQLBaseStore):
 
     __metaclass__ = abc.ABCMeta
 
-    def __init__(self, db_conn, hs):
-        super(StreamWorkerStore, self).__init__(db_conn, hs)
+    def __init__(self, database: Database, db_conn, hs):
+        super(StreamWorkerStore, self).__init__(database, db_conn, hs)
 
         events_max = self.get_room_max_stream_ordering()
         event_cache_prefill, min_event_val = self.db.get_cache_dict(
diff --git a/synapse/storage/data_stores/main/transactions.py b/synapse/storage/data_stores/main/transactions.py
index c0d155a43c..5b07c2fbc0 100644
--- a/synapse/storage/data_stores/main/transactions.py
+++ b/synapse/storage/data_stores/main/transactions.py
@@ -24,6 +24,7 @@ from twisted.internet import defer
 
 from synapse.metrics.background_process_metrics import run_as_background_process
 from synapse.storage._base import SQLBaseStore, db_to_json
+from synapse.storage.database import Database
 from synapse.util.caches.expiringcache import ExpiringCache
 
 # py2 sqlite has buffer hardcoded as only binary type, so we must use it,
@@ -52,8 +53,8 @@ class TransactionStore(SQLBaseStore):
     """A collection of queries for handling PDUs.
     """
 
-    def __init__(self, db_conn, hs):
-        super(TransactionStore, self).__init__(db_conn, hs)
+    def __init__(self, database: Database, db_conn, hs):
+        super(TransactionStore, self).__init__(database, db_conn, hs)
 
         self._clock.looping_call(self._start_cleanup_transactions, 30 * 60 * 1000)
 
diff --git a/synapse/storage/data_stores/main/user_directory.py b/synapse/storage/data_stores/main/user_directory.py
index 62ffb34b29..90c180ec6d 100644
--- a/synapse/storage/data_stores/main/user_directory.py
+++ b/synapse/storage/data_stores/main/user_directory.py
@@ -21,6 +21,7 @@ from twisted.internet import defer
 from synapse.api.constants import EventTypes, JoinRules
 from synapse.storage.data_stores.main.state import StateFilter
 from synapse.storage.data_stores.main.state_deltas import StateDeltasStore
+from synapse.storage.database import Database
 from synapse.storage.engines import PostgresEngine, Sqlite3Engine
 from synapse.types import get_domain_from_id, get_localpart_from_id
 from synapse.util.caches.descriptors import cached
@@ -37,8 +38,8 @@ class UserDirectoryBackgroundUpdateStore(StateDeltasStore):
     # add_users_who_share_private_rooms?
     SHARE_PRIVATE_WORKING_SET = 500
 
-    def __init__(self, db_conn, hs):
-        super(UserDirectoryBackgroundUpdateStore, self).__init__(db_conn, hs)
+    def __init__(self, database: Database, db_conn, hs):
+        super(UserDirectoryBackgroundUpdateStore, self).__init__(database, db_conn, hs)
 
         self.server_name = hs.hostname
 
@@ -549,8 +550,8 @@ class UserDirectoryStore(UserDirectoryBackgroundUpdateStore):
     # add_users_who_share_private_rooms?
     SHARE_PRIVATE_WORKING_SET = 500
 
-    def __init__(self, db_conn, hs):
-        super(UserDirectoryStore, self).__init__(db_conn, hs)
+    def __init__(self, database: Database, db_conn, hs):
+        super(UserDirectoryStore, self).__init__(database, db_conn, hs)
 
     def remove_from_user_dir(self, user_id):
         def _remove_from_user_dir_txn(txn):
diff --git a/synapse/storage/database.py b/synapse/storage/database.py
index 6843b7e7f8..ec19ae1d9d 100644
--- a/synapse/storage/database.py
+++ b/synapse/storage/database.py
@@ -234,7 +234,7 @@ class Database(object):
         #   to watch it
         self._txn_perf_counters = PerformanceCounters()
 
-        self.database_engine = hs.database_engine
+        self.engine = hs.database_engine
 
         # A set of tables that are not safe to use native upserts in.
         self._unsafe_to_upsert_tables = set(UNIQUE_INDEX_BACKGROUND_UPDATES.keys())
@@ -242,10 +242,10 @@ class Database(object):
         # We add the user_directory_search table to the blacklist on SQLite
         # because the existing search table does not have an index, making it
         # unsafe to use native upserts.
-        if isinstance(self.database_engine, Sqlite3Engine):
+        if isinstance(self.engine, Sqlite3Engine):
             self._unsafe_to_upsert_tables.add("user_directory_search")
 
-        if self.database_engine.can_native_upsert:
+        if self.engine.can_native_upsert:
             # Check ASAP (and then later, every 1s) to see if we have finished
             # background updates of tables that aren't safe to update.
             self._clock.call_later(
@@ -331,7 +331,7 @@ class Database(object):
                 cursor = LoggingTransaction(
                     conn.cursor(),
                     name,
-                    self.database_engine,
+                    self.engine,
                     after_callbacks,
                     exception_callbacks,
                 )
@@ -339,7 +339,7 @@ class Database(object):
                     r = func(cursor, *args, **kwargs)
                     conn.commit()
                     return r
-                except self.database_engine.module.OperationalError as e:
+                except self.engine.module.OperationalError as e:
                     # This can happen if the database disappears mid
                     # transaction.
                     logger.warning(
@@ -353,20 +353,20 @@ class Database(object):
                         i += 1
                         try:
                             conn.rollback()
-                        except self.database_engine.module.Error as e1:
+                        except self.engine.module.Error as e1:
                             logger.warning(
                                 "[TXN EROLL] {%s} %s", name, exception_to_unicode(e1)
                             )
                         continue
                     raise
-                except self.database_engine.module.DatabaseError as e:
-                    if self.database_engine.is_deadlock(e):
+                except self.engine.module.DatabaseError as e:
+                    if self.engine.is_deadlock(e):
                         logger.warning("[TXN DEADLOCK] {%s} %d/%d", name, i, N)
                         if i < N:
                             i += 1
                             try:
                                 conn.rollback()
-                            except self.database_engine.module.Error as e1:
+                            except self.engine.module.Error as e1:
                                 logger.warning(
                                     "[TXN EROLL] {%s} %s",
                                     name,
@@ -494,7 +494,7 @@ class Database(object):
                 sql_scheduling_timer.observe(sched_duration_sec)
                 context.add_database_scheduled(sched_duration_sec)
 
-                if self.database_engine.is_connection_closed(conn):
+                if self.engine.is_connection_closed(conn):
                     logger.debug("Reconnecting closed database connection")
                     conn.reconnect()
 
@@ -561,7 +561,7 @@ class Database(object):
         """
         try:
             yield self.runInteraction(desc, self.simple_insert_txn, table, values)
-        except self.database_engine.module.IntegrityError:
+        except self.engine.module.IntegrityError:
             # We have to do or_ignore flag at this layer, since we can't reuse
             # a cursor after we receive an error from the db.
             if not or_ignore:
@@ -660,7 +660,7 @@ class Database(object):
                     lock=lock,
                 )
                 return result
-            except self.database_engine.module.IntegrityError as e:
+            except self.engine.module.IntegrityError as e:
                 attempts += 1
                 if attempts >= 5:
                     # don't retry forever, because things other than races
@@ -692,10 +692,7 @@ class Database(object):
             upserts return True if a new entry was created, False if an existing
             one was updated.
         """
-        if (
-            self.database_engine.can_native_upsert
-            and table not in self._unsafe_to_upsert_tables
-        ):
+        if self.engine.can_native_upsert and table not in self._unsafe_to_upsert_tables:
             return self.simple_upsert_txn_native_upsert(
                 txn, table, keyvalues, values, insertion_values=insertion_values
             )
@@ -726,7 +723,7 @@ class Database(object):
         """
         # We need to lock the table :(, unless we're *really* careful
         if lock:
-            self.database_engine.lock_table(txn, table)
+            self.engine.lock_table(txn, table)
 
         def _getwhere(key):
             # If the value we're passing in is None (aka NULL), we need to use
@@ -828,10 +825,7 @@ class Database(object):
         Returns:
             None
         """
-        if (
-            self.database_engine.can_native_upsert
-            and table not in self._unsafe_to_upsert_tables
-        ):
+        if self.engine.can_native_upsert and table not in self._unsafe_to_upsert_tables:
             return self.simple_upsert_many_txn_native_upsert(
                 txn, table, key_names, key_values, value_names, value_values
             )
@@ -1301,7 +1295,7 @@ class Database(object):
             "limit": limit,
         }
 
-        sql = self.database_engine.convert_param_style(sql)
+        sql = self.engine.convert_param_style(sql)
 
         txn = db_conn.cursor()
         txn.execute(sql, (int(max_value),))