Diffstat (limited to 'synapse')

-rw-r--r--  synapse/app/homeserver.py                                         |   6
-rw-r--r--  synapse/config/saml2_config.py                                    | 186
-rw-r--r--  synapse/event_auth.py                                             |   8
-rw-r--r--  synapse/federation/federation_client.py                           | 191
-rw-r--r--  synapse/federation/federation_server.py                           |  15
-rw-r--r--  synapse/federation/transport/client.py                            |  33
-rw-r--r--  synapse/federation/transport/server.py                            |  32
-rw-r--r--  synapse/handlers/federation.py                                    | 112
-rw-r--r--  synapse/handlers/initial_sync.py                                  |  19
-rw-r--r--  synapse/handlers/message.py                                       |   5
-rw-r--r--  synapse/handlers/saml_handler.py                                  | 198
-rw-r--r--  synapse/handlers/search.py                                        |  34
-rw-r--r--  synapse/logging/context.py                                        |  11
-rw-r--r--  synapse/state/__init__.py                                         |   3
-rw-r--r--  synapse/storage/data_stores/main/client_ips.py                    |   8
-rw-r--r--  synapse/storage/data_stores/main/events_worker.py                 |  97
-rw-r--r--  synapse/storage/data_stores/main/schema/full_schemas/README.md    |  13
-rw-r--r--  synapse/storage/data_stores/main/schema/full_schemas/README.txt   |  19
-rw-r--r--  synapse/storage/data_stores/main/search.py                        |   8
-rw-r--r--  synapse/storage/data_stores/main/state.py                         |  18
-rw-r--r--  synapse/util/caches/snapshot_cache.py                             |  94

21 files changed, 700 insertions, 410 deletions
diff --git a/synapse/app/homeserver.py b/synapse/app/homeserver.py
index df65d0a989..032010600a 100644
--- a/synapse/app/homeserver.py
+++ b/synapse/app/homeserver.py
@@ -519,8 +519,10 @@ def phone_stats_home(hs, stats, stats_process=_stats_process):
     # Database version
     #
 
-    stats["database_engine"] = hs.database_engine.module.__name__
-    stats["database_server_version"] = hs.database_engine.server_version
+    # This only reports info about the *main* database.
+    stats["database_engine"] = hs.get_datastore().db.engine.module.__name__
+    stats["database_server_version"] = hs.get_datastore().db.engine.server_version
+
     logger.info("Reporting stats to %s: %s" % (hs.config.report_stats_endpoint, stats))
     try:
         yield hs.get_proxied_http_client().put_json(
diff --git a/synapse/config/saml2_config.py b/synapse/config/saml2_config.py
index c5ea2d43a1..b91414aa35 100644
--- a/synapse/config/saml2_config.py
+++ b/synapse/config/saml2_config.py
@@ -14,17 +14,19 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
-import re
+import logging
 
 from synapse.python_dependencies import DependencyException, check_requirements
-from synapse.types import (
-    map_username_to_mxid_localpart,
-    mxid_localpart_allowed_characters,
-)
-from synapse.util.module_loader import load_python_module
+from synapse.util.module_loader import load_module, load_python_module
 
 from ._base import Config, ConfigError
 
+logger = logging.getLogger(__name__)
+
+DEFAULT_USER_MAPPING_PROVIDER = (
+    "synapse.handlers.saml_handler.DefaultSamlMappingProvider"
+)
+
 
 def _dict_merge(merge_dict, into_dict):
     """Do a deep merge of two dicts
@@ -75,15 +77,69 @@ class SAML2Config(Config):
 
         self.saml2_enabled = True
 
-        self.saml2_mxid_source_attribute = saml2_config.get(
-            "mxid_source_attribute", "uid"
-        )
-
         self.saml2_grandfathered_mxid_source_attribute = saml2_config.get(
             "grandfathered_mxid_source_attribute", "uid"
         )
 
-        saml2_config_dict = self._default_saml_config_dict()
+        # user_mapping_provider may be None if the key is present but has no value
+        ump_dict = saml2_config.get("user_mapping_provider") or {}
+
+        # Use the default user mapping provider if not set
+        ump_dict.setdefault("module", DEFAULT_USER_MAPPING_PROVIDER)
+
+        # Ensure a config is present
+        ump_dict["config"] = ump_dict.get("config") or {}
+
+        if ump_dict["module"] == DEFAULT_USER_MAPPING_PROVIDER:
+            # Load deprecated options for use by the default module
+            old_mxid_source_attribute = saml2_config.get("mxid_source_attribute")
+            if old_mxid_source_attribute:
+                logger.warning(
+                    "The config option saml2_config.mxid_source_attribute is deprecated. "
+                    "Please use saml2_config.user_mapping_provider.config"
+                    ".mxid_source_attribute instead."
+                )
+                ump_dict["config"]["mxid_source_attribute"] = old_mxid_source_attribute
+
+            old_mxid_mapping = saml2_config.get("mxid_mapping")
+            if old_mxid_mapping:
+                logger.warning(
+                    "The config option saml2_config.mxid_mapping is deprecated. Please "
+                    "use saml2_config.user_mapping_provider.config.mxid_mapping instead."
+                )
+                ump_dict["config"]["mxid_mapping"] = old_mxid_mapping
+
+        # Retrieve an instance of the module's class
+        # Pass the config dictionary to the module for processing
+        (
+            self.saml2_user_mapping_provider_class,
+            self.saml2_user_mapping_provider_config,
+        ) = load_module(ump_dict)
+
+        # Ensure loaded user mapping module has defined all necessary methods
+        # Note parse_config() is already checked during the call to load_module
+        required_methods = [
+            "get_saml_attributes",
+            "saml_response_to_user_attributes",
+        ]
+        missing_methods = [
+            method
+            for method in required_methods
+            if not hasattr(self.saml2_user_mapping_provider_class, method)
+        ]
+        if missing_methods:
+            raise ConfigError(
+                "Class specified by saml2_config."
+                "user_mapping_provider.module is missing required "
+                "methods: %s" % (", ".join(missing_methods),)
+            )
+
+        # Get the desired saml auth response attributes from the module
+        saml2_config_dict = self._default_saml_config_dict(
+            *self.saml2_user_mapping_provider_class.get_saml_attributes(
+                self.saml2_user_mapping_provider_config
+            )
+        )
         _dict_merge(
             merge_dict=saml2_config.get("sp_config", {}), into_dict=saml2_config_dict
         )
@@ -103,22 +159,27 @@ class SAML2Config(Config):
             saml2_config.get("saml_session_lifetime", "5m")
         )
 
-        mapping = saml2_config.get("mxid_mapping", "hexencode")
-        try:
-            self.saml2_mxid_mapper = MXID_MAPPER_MAP[mapping]
-        except KeyError:
-            raise ConfigError("%s is not a known mxid_mapping" % (mapping,))
-
-    def _default_saml_config_dict(self):
+    def _default_saml_config_dict(
+        self, required_attributes: set, optional_attributes: set
+    ):
+        """Generate a configuration dictionary with required and optional attributes that
+        will be needed to process new user registration
+
+        Args:
+            required_attributes: SAML auth response attributes that are
+                necessary to function
+            optional_attributes: SAML auth response attributes that can be used to add
+                additional information to Synapse user accounts, but are not required
+
+        Returns:
+            dict: A SAML configuration dictionary
+        """
         import saml2
 
         public_baseurl = self.public_baseurl
         if public_baseurl is None:
             raise ConfigError("saml2_config requires a public_baseurl to be set")
 
-        required_attributes = {"uid", self.saml2_mxid_source_attribute}
-
-        optional_attributes = {"displayName"}
         if self.saml2_grandfathered_mxid_source_attribute:
             optional_attributes.add(self.saml2_grandfathered_mxid_source_attribute)
         optional_attributes -= required_attributes
@@ -207,33 +268,58 @@ class SAML2Config(Config):
           #
           #config_path: "%(config_dir_path)s/sp_conf.py"
 
-          # the lifetime of a SAML session. This defines how long a user has to
+          # The lifetime of a SAML session. This defines how long a user has to
           # complete the authentication process, if allow_unsolicited is unset.
           # The default is 5 minutes.
           #
           #saml_session_lifetime: 5m
 
-          # The SAML attribute (after mapping via the attribute maps) to use to derive
-          # the Matrix ID from. 'uid' by default.
+          # An external module can be provided here as a custom solution to
+          # mapping attributes returned from a saml provider onto a matrix user.
           #
-          #mxid_source_attribute: displayName
-
-          # The mapping system to use for mapping the saml attribute onto a matrix ID.
-          # Options include:
-          #  * 'hexencode' (which maps unpermitted characters to '=xx')
-          #  * 'dotreplace' (which replaces unpermitted characters with '.').
-          # The default is 'hexencode'.
-          #
-          #mxid_mapping: dotreplace
-
-          # In previous versions of synapse, the mapping from SAML attribute to MXID was
-          # always calculated dynamically rather than stored in a table. For backwards-
-          # compatibility, we will look for user_ids matching such a pattern before
-          # creating a new account.
+          user_mapping_provider:
+            # The custom module's class. Uncomment to use a custom module.
+            #
+            #module: mapping_provider.SamlMappingProvider
+
+            # Custom configuration values for the module. The below options are
+            # intended for the built-in provider; they should be changed if
+            # using a custom module. This section will be passed as a Python
+            # dictionary to the module's `parse_config` method.
+            #
+            config:
+              # The SAML attribute (after mapping via the attribute maps) to use
+              # to derive the Matrix ID from. 'uid' by default.
+              #
+              # Note: This used to be configured by the
+              # saml2_config.mxid_source_attribute option. If that is still
+              # defined, its value will be used instead.
+              #
+              #mxid_source_attribute: displayName
+
+              # The mapping system to use for mapping the saml attribute onto a
+              # matrix ID.
+              #
+              # Options include:
+              #  * 'hexencode' (which maps unpermitted characters to '=xx')
+              #  * 'dotreplace' (which replaces unpermitted characters with
+              #     '.').
+              # The default is 'hexencode'.
+              #
+              # Note: This used to be configured by the
+              # saml2_config.mxid_mapping option. If that is still defined, its
+              # value will be used instead.
+              #
+              #mxid_mapping: dotreplace
+
+          # In previous versions of synapse, the mapping from SAML attribute to
+          # MXID was always calculated dynamically rather than stored in a
+          # table. For backwards-compatibility, we will look for user_ids
+          # matching such a pattern before creating a new account.
           #
           # This setting controls the SAML attribute which will be used for this
-          # backwards-compatibility lookup. Typically it should be 'uid', but if the
-          # attribute maps are changed, it may be necessary to change it.
+          # backwards-compatibility lookup. Typically it should be 'uid', but if
+          # the attribute maps are changed, it may be necessary to change it.
           #
           # The default is 'uid'.
           #
@@ -241,23 +327,3 @@ class SAML2Config(Config):
         """ % {
             "config_dir_path": config_dir_path
         }
-
-
-DOT_REPLACE_PATTERN = re.compile(
-    ("[^%s]" % (re.escape("".join(mxid_localpart_allowed_characters)),))
-)
-
-
-def dot_replace_for_mxid(username: str) -> str:
-    username = username.lower()
-    username = DOT_REPLACE_PATTERN.sub(".", username)
-
-    # regular mxids aren't allowed to start with an underscore either
-    username = re.sub("^_", "", username)
-    return username
-
-
-MXID_MAPPER_MAP = {
-    "hexencode": map_username_to_mxid_localpart,
-    "dotreplace": dot_replace_for_mxid,
-}
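
Taken on its own, the deprecation shim above just folds the legacy top-level
options into the new user_mapping_provider block. A minimal standalone sketch
of what it produces (the legacy values here are illustrative):

    legacy = {"mxid_source_attribute": "displayName", "mxid_mapping": "dotreplace"}

    ump = legacy.get("user_mapping_provider") or {}
    ump.setdefault("module", "synapse.handlers.saml_handler.DefaultSamlMappingProvider")
    ump["config"] = ump.get("config") or {}
    for key in ("mxid_source_attribute", "mxid_mapping"):
        if key in legacy:
            ump["config"][key] = legacy[key]

    # ump == {"module": "synapse.handlers.saml_handler.DefaultSamlMappingProvider",
    #         "config": {"mxid_source_attribute": "displayName",
    #                    "mxid_mapping": "dotreplace"}}
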
diff --git a/synapse/event_auth.py b/synapse/event_auth.py
index ec3243b27b..c940b84470 100644
--- a/synapse/event_auth.py
+++ b/synapse/event_auth.py
@@ -42,6 +42,8 @@ def check(room_version, event, auth_events, do_sig_check=True, do_size_check=Tru
     Returns:
          if the auth checks pass.
     """
+    assert isinstance(auth_events, dict)
+
     if do_size_check:
         _check_size_limits(event)
 
@@ -74,12 +76,6 @@ def check(room_version, event, auth_events, do_sig_check=True, do_size_check=Tru
             if not event.signatures.get(event_id_domain):
                 raise AuthError(403, "Event not signed by sending server")
 
-    if auth_events is None:
-        # Oh, we don't know what the state of the room was, so we
-        # are trusting that this is allowed (at least for now)
-        logger.warning("Trusting event: %s", event.event_id)
-        return
-
     if event.type == EventTypes.Create:
         sender_domain = get_domain_from_id(event.sender)
         room_id_domain = get_domain_from_id(event.room_id)
diff --git a/synapse/federation/federation_client.py b/synapse/federation/federation_client.py
index 709449c9e3..af652a7659 100644
--- a/synapse/federation/federation_client.py
+++ b/synapse/federation/federation_client.py
@@ -18,8 +18,6 @@ import copy
 import itertools
 import logging
 
-from six.moves import range
-
 from prometheus_client import Counter
 
 from twisted.internet import defer
@@ -39,7 +37,7 @@ from synapse.api.room_versions import (
 )
 from synapse.events import builder, room_version_to_event_format
 from synapse.federation.federation_base import FederationBase, event_from_pdu_json
-from synapse.logging.context import make_deferred_yieldable, run_in_background
+from synapse.logging.context import make_deferred_yieldable
 from synapse.logging.utils import log_function
 from synapse.util import unwrapFirstError
 from synapse.util.caches.expiringcache import ExpiringCache
@@ -310,19 +308,12 @@ class FederationClient(FederationBase):
         return signed_pdu
 
     @defer.inlineCallbacks
-    @log_function
-    def get_state_for_room(self, destination, room_id, event_id):
-        """Requests all of the room state at a given event from a remote homeserver.
-
-        Args:
-            destination (str): The remote homeserver to query for the state.
-            room_id (str): The id of the room we're interested in.
-            event_id (str): The id of the event we want the state at.
+    def get_room_state_ids(self, destination: str, room_id: str, event_id: str):
+        """Calls the /state_ids endpoint to fetch the state at a particular point
+        in the room, and the auth events for the given event
 
         Returns:
-            Deferred[Tuple[List[EventBase], List[EventBase]]]:
-                A list of events in the state, and a list of events in the auth chain
-                for the given event.
+            Tuple[List[str], List[str]]: a tuple of (state event_ids, auth event_ids)
         """
         result = yield self.transport_layer.get_room_state_ids(
             destination, room_id, event_id=event_id
@@ -331,86 +322,12 @@ class FederationClient(FederationBase):
         state_event_ids = result["pdu_ids"]
         auth_event_ids = result.get("auth_chain_ids", [])
 
-        fetched_events, failed_to_fetch = yield self.get_events_from_store_or_dest(
-            destination, room_id, set(state_event_ids + auth_event_ids)
-        )
-
-        if failed_to_fetch:
-            logger.warning(
-                "Failed to fetch missing state/auth events for %s: %s",
-                room_id,
-                failed_to_fetch,
-            )
-
-        event_map = {ev.event_id: ev for ev in fetched_events}
+        if not isinstance(state_event_ids, list) or not isinstance(
+            auth_event_ids, list
+        ):
+            raise Exception("invalid response from /state_ids")
 
-        pdus = [event_map[e_id] for e_id in state_event_ids if e_id in event_map]
-        auth_chain = [event_map[e_id] for e_id in auth_event_ids if e_id in event_map]
-
-        auth_chain.sort(key=lambda e: e.depth)
-
-        return pdus, auth_chain
-
-    @defer.inlineCallbacks
-    def get_events_from_store_or_dest(self, destination, room_id, event_ids):
-        """Fetch events from a remote destination, checking if we already have them.
-
-        Args:
-            destination (str)
-            room_id (str)
-            event_ids (list)
-
-        Returns:
-            Deferred: A deferred resolving to a 2-tuple where the first is a list of
-            events and the second is a list of event ids that we failed to fetch.
-        """
-        seen_events = yield self.store.get_events(event_ids, allow_rejected=True)
-        signed_events = list(seen_events.values())
-
-        failed_to_fetch = set()
-
-        missing_events = set(event_ids)
-        for k in seen_events:
-            missing_events.discard(k)
-
-        if not missing_events:
-            return signed_events, failed_to_fetch
-
-        logger.debug(
-            "Fetching unknown state/auth events %s for room %s",
-            missing_events,
-            event_ids,
-        )
-
-        room_version = yield self.store.get_room_version(room_id)
-
-        batch_size = 20
-        missing_events = list(missing_events)
-        for i in range(0, len(missing_events), batch_size):
-            batch = set(missing_events[i : i + batch_size])
-
-            deferreds = [
-                run_in_background(
-                    self.get_pdu,
-                    destinations=[destination],
-                    event_id=e_id,
-                    room_version=room_version,
-                )
-                for e_id in batch
-            ]
-
-            res = yield make_deferred_yieldable(
-                defer.DeferredList(deferreds, consumeErrors=True)
-            )
-            for success, result in res:
-                if success and result:
-                    signed_events.append(result)
-                    batch.discard(result.event_id)
-
-            # We removed all events we successfully fetched from `batch`
-            failed_to_fetch.update(batch)
-
-        return signed_events, failed_to_fetch
+        return state_event_ids, auth_event_ids
 
     @defer.inlineCallbacks
     @log_function
@@ -609,13 +526,7 @@ class FederationClient(FederationBase):
 
         @defer.inlineCallbacks
         def send_request(destination):
-            time_now = self._clock.time_msec()
-            _, content = yield self.transport_layer.send_join(
-                destination=destination,
-                room_id=pdu.room_id,
-                event_id=pdu.event_id,
-                content=pdu.get_pdu_json(time_now),
-            )
+            content = yield self._do_send_join(destination, pdu)
 
             logger.debug("Got content: %s", content)
 
@@ -683,6 +594,44 @@ class FederationClient(FederationBase):
         return self._try_destination_list("send_join", destinations, send_request)
 
     @defer.inlineCallbacks
+    def _do_send_join(self, destination, pdu):
+        time_now = self._clock.time_msec()
+
+        try:
+            content = yield self.transport_layer.send_join_v2(
+                destination=destination,
+                room_id=pdu.room_id,
+                event_id=pdu.event_id,
+                content=pdu.get_pdu_json(time_now),
+            )
+
+            return content
+        except HttpResponseException as e:
+            if e.code in [400, 404]:
+                err = e.to_synapse_error()
+
+                # If we receive an error response that isn't a generic error, or an
+                # unrecognised endpoint error, we assume that the remote understands
+                # the v2 send_join API and this is a legitimate error.
+                if err.errcode not in [Codes.UNKNOWN, Codes.UNRECOGNIZED]:
+                    raise err
+            else:
+                raise e.to_synapse_error()
+
+        logger.debug("Couldn't send_join with the v2 API, falling back to the v1 API")
+
+        resp = yield self.transport_layer.send_join_v1(
+            destination=destination,
+            room_id=pdu.room_id,
+            event_id=pdu.event_id,
+            content=pdu.get_pdu_json(time_now),
+        )
+
+        # We expect the v1 API to respond with [200, content], so we only return the
+        # content.
+        return resp[1]
+
+    @defer.inlineCallbacks
     def send_invite(self, destination, room_id, event_id, pdu):
         room_version = yield self.store.get_room_version(room_id)
 
@@ -791,18 +740,50 @@ class FederationClient(FederationBase):
 
         @defer.inlineCallbacks
         def send_request(destination):
-            time_now = self._clock.time_msec()
-            _, content = yield self.transport_layer.send_leave(
+            content = yield self._do_send_leave(destination, pdu)
+
+            logger.debug("Got content: %s", content)
+            return None
+
+        return self._try_destination_list("send_leave", destinations, send_request)
+
+    @defer.inlineCallbacks
+    def _do_send_leave(self, destination, pdu):
+        time_now = self._clock.time_msec()
+
+        try:
+            content = yield self.transport_layer.send_leave_v2(
                 destination=destination,
                 room_id=pdu.room_id,
                 event_id=pdu.event_id,
                 content=pdu.get_pdu_json(time_now),
             )
 
-            logger.debug("Got content: %s", content)
-            return None
+            return content
+        except HttpResponseException as e:
+            if e.code in [400, 404]:
+                err = e.to_synapse_error()
 
-        return self._try_destination_list("send_leave", destinations, send_request)
+                # If we receive an error response that isn't a generic error, or an
+                # unrecognised endpoint error, we assume that the remote understands
+                # the v2 send_leave API and this is a legitimate error.
+                if err.errcode not in [Codes.UNKNOWN, Codes.UNRECOGNIZED]:
+                    raise err
+            else:
+                raise e.to_synapse_error()
+
+        logger.debug("Couldn't send_leave with the v2 API, falling back to the v1 API")
+
+        resp = yield self.transport_layer.send_leave_v1(
+            destination=destination,
+            room_id=pdu.room_id,
+            event_id=pdu.event_id,
+            content=pdu.get_pdu_json(time_now),
+        )
+
+        # We expect the v1 API to respond with [200, content], so we only return the
+        # content.
+        return resp[1]
 
     def get_public_rooms(
         self,
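
Both _do_send_join and _do_send_leave above follow the same shape: try the v2
endpoint first, and fall back to v1 only when the remote signals that it
doesn't recognise the endpoint. A hedged sketch of that pattern in isolation
(call_with_v1_fallback is illustrative, not part of this diff):

    from synapse.api.errors import Codes, HttpResponseException

    def call_with_v1_fallback(v2_call, v1_call):
        try:
            return v2_call()
        except HttpResponseException as e:
            if e.code not in (400, 404):
                raise e.to_synapse_error()
            err = e.to_synapse_error()
            if err.errcode not in (Codes.UNKNOWN, Codes.UNRECOGNIZED):
                # the remote understood the v2 endpoint; this is a real error
                raise err
        # the v1 API responds with [200, content]; return just the content
        return v1_call()[1]
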
diff --git a/synapse/federation/federation_server.py b/synapse/federation/federation_server.py
index 84d4eca041..d7ce333822 100644
--- a/synapse/federation/federation_server.py
+++ b/synapse/federation/federation_server.py
@@ -384,15 +384,10 @@ class FederationServer(FederationBase):
 
         res_pdus = await self.handler.on_send_join_request(origin, pdu)
         time_now = self._clock.time_msec()
-        return (
-            200,
-            {
-                "state": [p.get_pdu_json(time_now) for p in res_pdus["state"]],
-                "auth_chain": [
-                    p.get_pdu_json(time_now) for p in res_pdus["auth_chain"]
-                ],
-            },
-        )
+        return {
+            "state": [p.get_pdu_json(time_now) for p in res_pdus["state"]],
+            "auth_chain": [p.get_pdu_json(time_now) for p in res_pdus["auth_chain"]],
+        }
 
     async def on_make_leave_request(self, origin, room_id, user_id):
         origin_host, _ = parse_server_name(origin)
@@ -419,7 +414,7 @@ class FederationServer(FederationBase):
         pdu = await self._check_sigs_and_hash(room_version, pdu)
 
         await self.handler.on_send_leave_request(origin, pdu)
-        return 200, {}
+        return {}
 
     async def on_event_auth(self, origin, room_id, event_id):
         with (await self._server_linearizer.queue((origin, room_id))):
diff --git a/synapse/federation/transport/client.py b/synapse/federation/transport/client.py
index 46dba84cac..198257414b 100644
--- a/synapse/federation/transport/client.py
+++ b/synapse/federation/transport/client.py
@@ -243,7 +243,7 @@ class TransportLayerClient(object):
 
     @defer.inlineCallbacks
     @log_function
-    def send_join(self, destination, room_id, event_id, content):
+    def send_join_v1(self, destination, room_id, event_id, content):
         path = _create_v1_path("/send_join/%s/%s", room_id, event_id)
 
         response = yield self.client.put_json(
@@ -254,7 +254,18 @@ class TransportLayerClient(object):
 
     @defer.inlineCallbacks
     @log_function
-    def send_leave(self, destination, room_id, event_id, content):
+    def send_join_v2(self, destination, room_id, event_id, content):
+        path = _create_v2_path("/send_join/%s/%s", room_id, event_id)
+
+        response = yield self.client.put_json(
+            destination=destination, path=path, data=content
+        )
+
+        return response
+
+    @defer.inlineCallbacks
+    @log_function
+    def send_leave_v1(self, destination, room_id, event_id, content):
         path = _create_v1_path("/send_leave/%s/%s", room_id, event_id)
 
         response = yield self.client.put_json(
@@ -272,6 +283,24 @@ class TransportLayerClient(object):
 
     @defer.inlineCallbacks
     @log_function
+    def send_leave_v2(self, destination, room_id, event_id, content):
+        path = _create_v2_path("/send_leave/%s/%s", room_id, event_id)
+
+        response = yield self.client.put_json(
+            destination=destination,
+            path=path,
+            data=content,
+            # we want to do our best to send this through. The problem is
+            # that if it fails, we won't retry it later, so if the remote
+            # server was just having a momentary blip, the room will be out of
+            # sync.
+            ignore_backoff=True,
+        )
+
+        return response
+
+    @defer.inlineCallbacks
+    @log_function
     def send_invite_v1(self, destination, room_id, event_id, content):
         path = _create_v1_path("/invite/%s/%s", room_id, event_id)
 
diff --git a/synapse/federation/transport/server.py b/synapse/federation/transport/server.py
index fefc789c85..b4cbf23394 100644
--- a/synapse/federation/transport/server.py
+++ b/synapse/federation/transport/server.py
@@ -506,11 +506,21 @@ class FederationMakeLeaveServlet(BaseFederationServlet):
         return 200, content
 
 
-class FederationSendLeaveServlet(BaseFederationServlet):
+class FederationV1SendLeaveServlet(BaseFederationServlet):
     PATH = "/send_leave/(?P<room_id>[^/]*)/(?P<event_id>[^/]*)"
 
     async def on_PUT(self, origin, content, query, room_id, event_id):
         content = await self.handler.on_send_leave_request(origin, content, room_id)
+        return 200, (200, content)
+
+
+class FederationV2SendLeaveServlet(BaseFederationServlet):
+    PATH = "/send_leave/(?P<room_id>[^/]*)/(?P<event_id>[^/]*)"
+
+    PREFIX = FEDERATION_V2_PREFIX
+
+    async def on_PUT(self, origin, content, query, room_id, event_id):
+        content = await self.handler.on_send_leave_request(origin, content, room_id)
         return 200, content
 
 
@@ -521,9 +531,21 @@ class FederationEventAuthServlet(BaseFederationServlet):
         return await self.handler.on_event_auth(origin, context, event_id)
 
 
-class FederationSendJoinServlet(BaseFederationServlet):
+class FederationV1SendJoinServlet(BaseFederationServlet):
+    PATH = "/send_join/(?P<context>[^/]*)/(?P<event_id>[^/]*)"
+
+    async def on_PUT(self, origin, content, query, context, event_id):
+        # TODO(paul): assert that context/event_id parsed from path actually
+        #   match those given in content
+        content = await self.handler.on_send_join_request(origin, content, context)
+        return 200, (200, content)
+
+
+class FederationV2SendJoinServlet(BaseFederationServlet):
     PATH = "/send_join/(?P<context>[^/]*)/(?P<event_id>[^/]*)"
 
+    PREFIX = FEDERATION_V2_PREFIX
+
     async def on_PUT(self, origin, content, query, context, event_id):
         # TODO(paul): assert that context/event_id parsed from path actually
         #   match those given in content
@@ -1367,8 +1389,10 @@ FEDERATION_SERVLET_CLASSES = (
     FederationMakeJoinServlet,
     FederationMakeLeaveServlet,
     FederationEventServlet,
-    FederationSendJoinServlet,
-    FederationSendLeaveServlet,
+    FederationV1SendJoinServlet,
+    FederationV2SendJoinServlet,
+    FederationV1SendLeaveServlet,
+    FederationV2SendLeaveServlet,
     FederationV1InviteServlet,
     FederationV2InviteServlet,
     FederationQueryAuthServlet,
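
For clarity, the response bodies the two servlet generations produce differ
only in the historical double-wrapping (shapes assumed from the handlers
above):

    # PUT /_matrix/federation/v1/send_join/{roomId}/{eventId}
    #   -> [200, {"state": [...], "auth_chain": [...]}]
    # PUT /_matrix/federation/v2/send_join/{roomId}/{eventId}
    #   -> {"state": [...], "auth_chain": [...]}
    # PUT /_matrix/federation/v1/send_leave/{roomId}/{eventId}
    #   -> [200, {}]
    # PUT /_matrix/federation/v2/send_leave/{roomId}/{eventId}
    #   -> {}
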
diff --git a/synapse/handlers/federation.py b/synapse/handlers/federation.py
index bc26921768..8f3c9d7702 100644
--- a/synapse/handlers/federation.py
+++ b/synapse/handlers/federation.py
@@ -63,8 +63,9 @@ from synapse.replication.http.federation import (
 )
 from synapse.replication.http.membership import ReplicationUserJoinedLeftRoomRestServlet
 from synapse.state import StateResolutionStore, resolve_events_with_store
+from synapse.storage.data_stores.main.events_worker import EventRedactBehaviour
 from synapse.types import UserID, get_domain_from_id
-from synapse.util import unwrapFirstError
+from synapse.util import batch_iter, unwrapFirstError
 from synapse.util.async_helpers import Linearizer
 from synapse.util.distributor import user_joined_room
 from synapse.util.retryutils import NotRetryingDestination
@@ -379,11 +380,9 @@ class FederationHandler(BaseHandler):
                             (
                                 remote_state,
                                 got_auth_chain,
-                            ) = yield self.federation_client.get_state_for_room(
-                                origin, room_id, p
-                            )
+                            ) = yield self._get_state_for_room(origin, room_id, p)
 
-                            # we want the state *after* p; get_state_for_room returns the
+                            # we want the state *after* p; _get_state_for_room returns the
                             # state *before* p.
                             remote_event = yield self.federation_client.get_pdu(
                                 [origin], p, room_version, outlier=True
@@ -425,7 +424,7 @@ class FederationHandler(BaseHandler):
                     evs = yield self.store.get_events(
                         list(state_map.values()),
                         get_prev_content=False,
-                        check_redacted=False,
+                        redact_behaviour=EventRedactBehaviour.AS_IS,
                     )
                     event_map.update(evs)
 
@@ -584,6 +583,97 @@ class FederationHandler(BaseHandler):
                         raise
 
     @defer.inlineCallbacks
+    @log_function
+    def _get_state_for_room(self, destination, room_id, event_id):
+        """Requests all of the room state at a given event from a remote homeserver.
+
+        Args:
+            destination (str): The remote homeserver to query for the state.
+            room_id (str): The id of the room we're interested in.
+            event_id (str): The id of the event we want the state at.
+
+        Returns:
+            Deferred[Tuple[List[EventBase], List[EventBase]]]:
+                A list of events in the state, and a list of events in the auth chain
+                for the given event.
+        """
+        (
+            state_event_ids,
+            auth_event_ids,
+        ) = yield self.federation_client.get_room_state_ids(
+            destination, room_id, event_id=event_id
+        )
+
+        desired_events = set(state_event_ids + auth_event_ids)
+        event_map = yield self._get_events_from_store_or_dest(
+            destination, room_id, desired_events
+        )
+
+        failed_to_fetch = desired_events - event_map.keys()
+        if failed_to_fetch:
+            logger.warning(
+                "Failed to fetch missing state/auth events for %s: %s",
+                room_id,
+                failed_to_fetch,
+            )
+
+        pdus = [event_map[e_id] for e_id in state_event_ids if e_id in event_map]
+        auth_chain = [event_map[e_id] for e_id in auth_event_ids if e_id in event_map]
+
+        auth_chain.sort(key=lambda e: e.depth)
+
+        return pdus, auth_chain
+
+    @defer.inlineCallbacks
+    def _get_events_from_store_or_dest(self, destination, room_id, event_ids):
+        """Fetch events from a remote destination, checking if we already have them.
+
+        Args:
+            destination (str)
+            room_id (str)
+            event_ids (Iterable[str])
+
+        Returns:
+            Deferred[dict[str, EventBase]]: A deferred resolving to a map
+            from event_id to event
+        """
+        fetched_events = yield self.store.get_events(event_ids, allow_rejected=True)
+
+        missing_events = set(event_ids) - fetched_events.keys()
+
+        if not missing_events:
+            return fetched_events
+
+        logger.debug(
+            "Fetching unknown state/auth events %s for room %s",
+            missing_events,
+            room_id,
+        )
+
+        room_version = yield self.store.get_room_version(room_id)
+
+        # XXX 20 requests at once? really?
+        for batch in batch_iter(missing_events, 20):
+            deferreds = [
+                run_in_background(
+                    self.federation_client.get_pdu,
+                    destinations=[destination],
+                    event_id=e_id,
+                    room_version=room_version,
+                )
+                for e_id in batch
+            ]
+
+            res = yield make_deferred_yieldable(
+                defer.DeferredList(deferreds, consumeErrors=True)
+            )
+            for success, result in res:
+                if success and result:
+                    fetched_events[result.event_id] = result
+
+        return fetched_events
+
+    @defer.inlineCallbacks
     def _process_received_pdu(self, origin, event, state, auth_chain):
         """ Called when we have a new pdu. We need to do auth checks and put it
         through the StateHandler.
@@ -723,7 +813,7 @@ class FederationHandler(BaseHandler):
         state_events = {}
         events_to_state = {}
         for e_id in edges:
-            state, auth = yield self.federation_client.get_state_for_room(
+            state, auth = yield self._get_state_for_room(
                 destination=dest, room_id=room_id, event_id=e_id
             )
             auth_events.update({a.event_id: a for a in auth})
@@ -911,7 +1001,9 @@ class FederationHandler(BaseHandler):
         forward_events = yield self.store.get_successor_events(list(extremities))
 
         extremities_events = yield self.store.get_events(
-            forward_events, check_redacted=False, get_prev_content=False
+            forward_events,
+            redact_behaviour=EventRedactBehaviour.AS_IS,
+            get_prev_content=False,
         )
 
         # We set `check_history_visibility_only` as we might otherwise get false
@@ -1210,7 +1302,7 @@ class FederationHandler(BaseHandler):
             # Check whether this room is the result of an upgrade of a room we already know
             # about. If so, migrate over user information
             predecessor = yield self.store.get_room_predecessor(room_id)
-            if not predecessor:
+            if not predecessor or not isinstance(predecessor.get("room_id"), str):
                 return
             old_room_id = predecessor["room_id"]
             logger.debug(
@@ -1453,7 +1545,7 @@ class FederationHandler(BaseHandler):
     @defer.inlineCallbacks
     def do_remotely_reject_invite(self, target_hosts, room_id, user_id, content):
         origin, event, event_format_version = yield self._make_and_verify_event(
-            target_hosts, room_id, user_id, "leave", content=content,
+            target_hosts, room_id, user_id, "leave", content=content
         )
         # Mark as outlier as we don't have any state for this event; we're not
         # even in the room.
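
_get_events_from_store_or_dest above relies on batch_iter to cap the number
of concurrent get_pdu requests at 20. A small usage sketch, assuming
batch_iter yields fixed-size tuples as in synapse.util:

    from synapse.util import batch_iter

    missing_events = ["$a:hs", "$b:hs", "$c:hs", "$d:hs", "$e:hs"]
    for batch in batch_iter(missing_events, 2):
        print(batch)
    # ("$a:hs", "$b:hs")
    # ("$c:hs", "$d:hs")
    # ("$e:hs",)
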
diff --git a/synapse/handlers/initial_sync.py b/synapse/handlers/initial_sync.py
index 81dce96f4b..73c110a92b 100644
--- a/synapse/handlers/initial_sync.py
+++ b/synapse/handlers/initial_sync.py
@@ -26,7 +26,7 @@ from synapse.streams.config import PaginationConfig
 from synapse.types import StreamToken, UserID
 from synapse.util import unwrapFirstError
 from synapse.util.async_helpers import concurrently_execute
-from synapse.util.caches.snapshot_cache import SnapshotCache
+from synapse.util.caches.response_cache import ResponseCache
 from synapse.visibility import filter_events_for_client
 
 from ._base import BaseHandler
@@ -41,7 +41,7 @@ class InitialSyncHandler(BaseHandler):
         self.state = hs.get_state_handler()
         self.clock = hs.get_clock()
         self.validator = EventValidator()
-        self.snapshot_cache = SnapshotCache()
+        self.snapshot_cache = ResponseCache(hs, "initial_sync_cache")
         self._event_serializer = hs.get_event_client_serializer()
         self.storage = hs.get_storage()
         self.state_store = self.storage.state
@@ -79,17 +79,14 @@ class InitialSyncHandler(BaseHandler):
             as_client_event,
             include_archived,
         )
-        now_ms = self.clock.time_msec()
-        result = self.snapshot_cache.get(now_ms, key)
-        if result is not None:
-            return result
 
-        return self.snapshot_cache.set(
-            now_ms,
+        return self.snapshot_cache.wrap(
             key,
-            self._snapshot_all_rooms(
-                user_id, pagin_config, as_client_event, include_archived
-            ),
+            self._snapshot_all_rooms,
+            user_id,
+            pagin_config,
+            as_client_event,
+            include_archived,
         )
 
     @defer.inlineCallbacks
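
A hedged sketch of the ResponseCache.wrap contract this change relies on
(expensive_fn and its arguments are placeholders): the first caller for a
given key runs the callback, and any caller arriving while that deferred is
still pending shares it rather than recomputing.

    cache = ResponseCache(hs, "initial_sync_cache")

    d1 = cache.wrap(key, expensive_fn, arg1, arg2)   # runs expensive_fn
    d2 = cache.wrap(key, expensive_fn, arg1, arg2)   # shares d1's computation
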
diff --git a/synapse/handlers/message.py b/synapse/handlers/message.py
index 54fa216d83..bf9add7fe2 100644
--- a/synapse/handlers/message.py
+++ b/synapse/handlers/message.py
@@ -46,6 +46,7 @@ from synapse.events.validator import EventValidator
 from synapse.logging.context import run_in_background
 from synapse.metrics.background_process_metrics import run_as_background_process
 from synapse.replication.http.send_event import ReplicationSendEventRestServlet
+from synapse.storage.data_stores.main.events_worker import EventRedactBehaviour
 from synapse.storage.state import StateFilter
 from synapse.types import RoomAlias, UserID, create_requester
 from synapse.util.async_helpers import Linearizer
@@ -875,7 +876,7 @@ class EventCreationHandler(object):
             if event.type == EventTypes.Redaction:
                 original_event = yield self.store.get_event(
                     event.redacts,
-                    check_redacted=False,
+                    redact_behaviour=EventRedactBehaviour.AS_IS,
                     get_prev_content=False,
                     allow_rejected=False,
                     allow_none=True,
@@ -952,7 +953,7 @@ class EventCreationHandler(object):
         if event.type == EventTypes.Redaction:
             original_event = yield self.store.get_event(
                 event.redacts,
-                check_redacted=False,
+                redact_behaviour=EventRedactBehaviour.AS_IS,
                 get_prev_content=False,
                 allow_rejected=False,
                 allow_none=True,
diff --git a/synapse/handlers/saml_handler.py b/synapse/handlers/saml_handler.py
index cc9e6b9bd0..0082f85c26 100644
--- a/synapse/handlers/saml_handler.py
+++ b/synapse/handlers/saml_handler.py
@@ -13,20 +13,36 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 import logging
+import re
+from typing import Tuple
 
 import attr
 import saml2
+import saml2.response
 from saml2.client import Saml2Client
 
 from synapse.api.errors import SynapseError
+from synapse.config import ConfigError
 from synapse.http.servlet import parse_string
 from synapse.rest.client.v1.login import SSOAuthHandler
-from synapse.types import UserID, map_username_to_mxid_localpart
+from synapse.types import (
+    UserID,
+    map_username_to_mxid_localpart,
+    mxid_localpart_allowed_characters,
+)
 from synapse.util.async_helpers import Linearizer
 
 logger = logging.getLogger(__name__)
 
 
+@attr.s
+class Saml2SessionData:
+    """Data we track about SAML2 sessions"""
+
+    # time the session was created, in milliseconds
+    creation_time = attr.ib()
+
+
 class SamlHandler:
     def __init__(self, hs):
         self._saml_client = Saml2Client(hs.config.saml2_sp_config)
@@ -37,11 +53,14 @@ class SamlHandler:
         self._datastore = hs.get_datastore()
         self._hostname = hs.hostname
         self._saml2_session_lifetime = hs.config.saml2_session_lifetime
-        self._mxid_source_attribute = hs.config.saml2_mxid_source_attribute
         self._grandfathered_mxid_source_attribute = (
             hs.config.saml2_grandfathered_mxid_source_attribute
         )
-        self._mxid_mapper = hs.config.saml2_mxid_mapper
+
+        # plugin to do custom mapping from saml response to mxid
+        self._user_mapping_provider = hs.config.saml2_user_mapping_provider_class(
+            hs.config.saml2_user_mapping_provider_config
+        )
 
         # identifier for the external_ids table
         self._auth_provider_id = "saml"
@@ -118,22 +137,10 @@ class SamlHandler:
             remote_user_id = saml2_auth.ava["uid"][0]
         except KeyError:
             logger.warning("SAML2 response lacks a 'uid' attestation")
-            raise SynapseError(400, "uid not in SAML2 response")
-
-        try:
-            mxid_source = saml2_auth.ava[self._mxid_source_attribute][0]
-        except KeyError:
-            logger.warning(
-                "SAML2 response lacks a '%s' attestation", self._mxid_source_attribute
-            )
-            raise SynapseError(
-                400, "%s not in SAML2 response" % (self._mxid_source_attribute,)
-            )
+            raise SynapseError(400, "'uid' not in SAML2 response")
 
         self._outstanding_requests_dict.pop(saml2_auth.in_response_to, None)
 
-        displayName = saml2_auth.ava.get("displayName", [None])[0]
-
         with (await self._mapping_lock.queue(self._auth_provider_id)):
             # first of all, check if we already have a mapping for this user
             logger.info(
@@ -173,22 +180,46 @@ class SamlHandler:
                     )
                     return registered_user_id
 
-            # figure out a new mxid for this user
-            base_mxid_localpart = self._mxid_mapper(mxid_source)
+            # Map saml response to user attributes using the configured mapping provider
+            for i in range(1000):
+                attribute_dict = self._user_mapping_provider.saml_response_to_user_attributes(
+                    saml2_auth, i
+                )
+
+                logger.debug(
+                    "Retrieved SAML attributes from user mapping provider: %s "
+                    "(attempt %d)",
+                    attribute_dict,
+                    i,
+                )
+
+                localpart = attribute_dict.get("mxid_localpart")
+                if not localpart:
+                    logger.error(
+                        "SAML mapping provider plugin did not return a "
+                        "mxid_localpart object"
+                    )
+                    raise SynapseError(500, "Error parsing SAML2 response")
 
-            suffix = 0
-            while True:
-                localpart = base_mxid_localpart + (str(suffix) if suffix else "")
+                displayname = attribute_dict.get("displayname")
+
+                # Check if this mxid already exists
                 if not await self._datastore.get_users_by_id_case_insensitive(
                     UserID(localpart, self._hostname).to_string()
                 ):
+                    # This mxid is free
                     break
-                suffix += 1
-            logger.info("Allocating mxid for new user with localpart %s", localpart)
+            else:
+                # Unable to generate a username in 1000 iterations
+                # Raise an error to the user
+                raise SynapseError(
+                    500, "Unable to generate a Matrix ID from the SAML response"
+                )
 
             registered_user_id = await self._registration_handler.register_user(
-                localpart=localpart, default_display_name=displayName
+                localpart=localpart, default_display_name=displayname
             )
+
             await self._datastore.record_user_external_id(
                 self._auth_provider_id, remote_user_id, registered_user_id
             )
@@ -205,9 +236,120 @@ class SamlHandler:
             del self._outstanding_requests_dict[reqid]
 
 
+DOT_REPLACE_PATTERN = re.compile(
+    ("[^%s]" % (re.escape("".join(mxid_localpart_allowed_characters)),))
+)
+
+
+def dot_replace_for_mxid(username: str) -> str:
+    username = username.lower()
+    username = DOT_REPLACE_PATTERN.sub(".", username)
+
+    # regular mxids aren't allowed to start with an underscore either
+    username = re.sub("^_", "", username)
+    return username
+
+
+MXID_MAPPER_MAP = {
+    "hexencode": map_username_to_mxid_localpart,
+    "dotreplace": dot_replace_for_mxid,
+}
+
+
 @attr.s
-class Saml2SessionData:
-    """Data we track about SAML2 sessions"""
+class SamlConfig(object):
+    mxid_source_attribute = attr.ib()
+    mxid_mapper = attr.ib()
 
-    # time the session was created, in milliseconds
-    creation_time = attr.ib()
+
+class DefaultSamlMappingProvider(object):
+    __version__ = "0.0.1"
+
+    def __init__(self, parsed_config: SamlConfig):
+        """The default SAML user mapping provider
+
+        Args:
+            parsed_config: Module configuration
+        """
+        self._mxid_source_attribute = parsed_config.mxid_source_attribute
+        self._mxid_mapper = parsed_config.mxid_mapper
+
+    def saml_response_to_user_attributes(
+        self, saml_response: saml2.response.AuthnResponse, failures: int = 0,
+    ) -> dict:
+        """Maps some text from a SAML response to attributes of a new user
+
+        Args:
+            saml_response: A SAML auth response object
+
+            failures: How many times a call to this function with this
+                saml_response has resulted in a failure
+
+        Returns:
+            dict: A dict containing new user attributes. Possible keys:
+                * mxid_localpart (str): Required. The localpart of the user's mxid
+                * displayname (str): The displayname of the user
+        """
+        try:
+            mxid_source = saml_response.ava[self._mxid_source_attribute][0]
+        except KeyError:
+            logger.warning(
+                "SAML2 response lacks a '%s' attestation", self._mxid_source_attribute,
+            )
+            raise SynapseError(
+                400, "%s not in SAML2 response" % (self._mxid_source_attribute,)
+            )
+
+        # Use the configured mapper for this mxid_source
+        base_mxid_localpart = self._mxid_mapper(mxid_source)
+
+        # Append suffix integer if last call to this function failed to produce
+        # a usable mxid
+        localpart = base_mxid_localpart + (str(failures) if failures else "")
+
+        # Retrieve the display name from the saml response
+        # If displayname is None, the mxid_localpart will be used instead
+        displayname = saml_response.ava.get("displayName", [None])[0]
+
+        return {
+            "mxid_localpart": localpart,
+            "displayname": displayname,
+        }
+
+    @staticmethod
+    def parse_config(config: dict) -> SamlConfig:
+        """Parse the dict provided by the homeserver's config
+        Args:
+            config: A dictionary containing configuration options for this provider
+        Returns:
+            SamlConfig: A custom config object for this module
+        """
+        # Parse config options and use defaults where necessary
+        mxid_source_attribute = config.get("mxid_source_attribute", "uid")
+        mapping_type = config.get("mxid_mapping", "hexencode")
+
+        # Retrieve the associated mapping function
+        try:
+            mxid_mapper = MXID_MAPPER_MAP[mapping_type]
+        except KeyError:
+            raise ConfigError(
+                "saml2_config.user_mapping_provider.config: '%s' is not a valid "
+                "mxid_mapping value" % (mapping_type,)
+            )
+
+        return SamlConfig(mxid_source_attribute, mxid_mapper)
+
+    @staticmethod
+    def get_saml_attributes(config: SamlConfig) -> Tuple[set, set]:
+        """Returns the required attributes of a SAML
+
+        Args:
+            config: A SamlConfig object containing configuration params for this provider
+
+        Returns:
+            tuple[set,set]: The first set equates to the saml auth response
+                attributes that are required for the module to function, whereas the
+                second set consists of those attributes which can be used if
+                available, but are not necessary
+        """
+        return {"uid", config.mxid_source_attribute}, {"displayName"}
diff --git a/synapse/handlers/search.py b/synapse/handlers/search.py
index 56ed262a1f..ef750d1497 100644
--- a/synapse/handlers/search.py
+++ b/synapse/handlers/search.py
@@ -21,7 +21,7 @@ from unpaddedbase64 import decode_base64, encode_base64
 from twisted.internet import defer
 
 from synapse.api.constants import EventTypes, Membership
-from synapse.api.errors import SynapseError
+from synapse.api.errors import NotFoundError, SynapseError
 from synapse.api.filtering import Filter
 from synapse.storage.state import StateFilter
 from synapse.visibility import filter_events_for_client
@@ -37,6 +37,7 @@ class SearchHandler(BaseHandler):
         self._event_serializer = hs.get_event_client_serializer()
         self.storage = hs.get_storage()
         self.state_store = self.storage.state
+        self.auth = hs.get_auth()
 
     @defer.inlineCallbacks
     def get_old_rooms_from_upgraded_room(self, room_id):
@@ -53,23 +54,38 @@ class SearchHandler(BaseHandler):
             room_id (str): id of the room to search through.
 
         Returns:
-            Deferred[iterable[unicode]]: predecessor room ids
+            Deferred[iterable[str]]: predecessor room ids
         """
 
         historical_room_ids = []
 
-        while True:
-            predecessor = yield self.store.get_room_predecessor(room_id)
+        # The initial room must have been known for us to get this far
+        predecessor = yield self.store.get_room_predecessor(room_id)
 
-            # If no predecessor, assume we've hit a dead end
+        while True:
             if not predecessor:
+                # We have reached the end of the chain of predecessors
+                break
+
+            if not isinstance(predecessor.get("room_id"), str):
+                # This predecessor object is malformed. Exit here
+                break
+
+            predecessor_room_id = predecessor["room_id"]
+
+            # Don't add it to the list until we have checked that we are in the room
+            try:
+                next_predecessor_room = yield self.store.get_room_predecessor(
+                    predecessor_room_id
+                )
+            except NotFoundError:
+                # The predecessor is not a known room, so we are done here
                 break
 
-            # Add predecessor's room ID
-            historical_room_ids.append(predecessor["room_id"])
+            historical_room_ids.append(predecessor_room_id)
 
-            # Scan through the old room for further predecessors
-            room_id = predecessor["room_id"]
+            # And repeat
+            predecessor = next_predecessor_room
 
         return historical_room_ids
 
diff --git a/synapse/logging/context.py b/synapse/logging/context.py
index 2c1fb9ddac..6747f29e6a 100644
--- a/synapse/logging/context.py
+++ b/synapse/logging/context.py
@@ -23,6 +23,7 @@ them.
 See doc/log_contexts.rst for details on how this works.
 """
 
+import inspect
 import logging
 import threading
 import types
@@ -612,7 +613,8 @@ def run_in_background(f, *args, **kwargs):
 
 
 def make_deferred_yieldable(deferred):
-    """Given a deferred, make it follow the Synapse logcontext rules:
+    """Given a deferred (or coroutine), make it follow the Synapse logcontext
+    rules:
 
     If the deferred has completed (or is not actually a Deferred), essentially
     does nothing (just returns another completed deferred with the
@@ -624,6 +626,13 @@ def make_deferred_yieldable(deferred):
 
     (This is more-or-less the opposite operation to run_in_background.)
     """
+    if inspect.isawaitable(deferred):
+        # If we're given a coroutine we convert it to a deferred so that we
+        # run it and find out if it immediately finishes; if it does then we
+        # don't need to fiddle with log contexts at all and can return
+        # immediately.
+        deferred = defer.ensureDeferred(deferred)
+
     if not isinstance(deferred, defer.Deferred):
         return deferred
 
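
With the inspect.isawaitable branch above, make_deferred_yieldable now
accepts coroutines as well as Deferreds. A small usage sketch (do_stuff is a
placeholder):

    from twisted.internet import defer
    from synapse.logging.context import make_deferred_yieldable

    async def do_stuff():
        return 42

    @defer.inlineCallbacks
    def caller():
        # the coroutine is converted via defer.ensureDeferred internally
        result = yield make_deferred_yieldable(do_stuff())
        return result
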
diff --git a/synapse/state/__init__.py b/synapse/state/__init__.py
index 139beef8ed..3e6d62eef1 100644
--- a/synapse/state/__init__.py
+++ b/synapse/state/__init__.py
@@ -32,6 +32,7 @@ from synapse.events import EventBase
 from synapse.events.snapshot import EventContext
 from synapse.logging.utils import log_function
 from synapse.state import v1, v2
+from synapse.storage.data_stores.main.events_worker import EventRedactBehaviour
 from synapse.util.async_helpers import Linearizer
 from synapse.util.caches import get_cache_factor_for
 from synapse.util.caches.expiringcache import ExpiringCache
@@ -645,7 +646,7 @@ class StateResolutionStore(object):
 
         return self.store.get_events(
             event_ids,
-            check_redacted=False,
+            redact_behaviour=EventRedactBehaviour.AS_IS,
             get_prev_content=False,
             allow_rejected=allow_rejected,
         )
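
The redact_behaviour argument used here is defined by the EventRedactBehaviour
enum introduced in events_worker.py further down this diff; a sketch of the
three values as the call sites use them (semantics taken from that docstring):

    from synapse.storage.data_stores.main.events_worker import EventRedactBehaviour

    # AS_IS:  return the stored event untouched, ignoring any redaction
    # REDACT: return the event with its content pruned (the old
    #         check_redacted=True behaviour)
    # BLOCK:  do not return redacted events at all
    events = yield store.get_events(
        event_ids, redact_behaviour=EventRedactBehaviour.AS_IS
    )
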
diff --git a/synapse/storage/data_stores/main/client_ips.py b/synapse/storage/data_stores/main/client_ips.py
index 320c5b0f07..add3037b69 100644
--- a/synapse/storage/data_stores/main/client_ips.py
+++ b/synapse/storage/data_stores/main/client_ips.py
@@ -451,16 +451,18 @@ class ClientIpStore(ClientIpBackgroundUpdateStore):
                 # Technically an access token might not be associated with
                 # a device so we need to check.
                 if device_id:
-                    self.db.simple_upsert_txn(
+                    # this is always an update rather than an upsert: the row should
+                    # already exist, and if it doesn't, that may be because it has been
+                    # deleted, and we don't want to re-create it.
+                    self.db.simple_update_txn(
                         txn,
                         table="devices",
                         keyvalues={"user_id": user_id, "device_id": device_id},
-                        values={
+                        updatevalues={
                             "user_agent": user_agent,
                             "last_seen": last_seen,
                             "ip": ip,
                         },
-                        lock=False,
                     )
             except Exception as e:
                 # Failed to upsert, log and continue
diff --git a/synapse/storage/data_stores/main/events_worker.py b/synapse/storage/data_stores/main/events_worker.py
index 9ee117ce0f..2c9142814c 100644
--- a/synapse/storage/data_stores/main/events_worker.py
+++ b/synapse/storage/data_stores/main/events_worker.py
@@ -19,8 +19,10 @@ import itertools
 import logging
 import threading
 from collections import namedtuple
+from typing import List, Optional
 
 from canonicaljson import json
+from constantly import NamedConstant, Names
 
 from twisted.internet import defer
 
@@ -55,6 +57,16 @@ EVENT_QUEUE_TIMEOUT_S = 0.1  # Timeout when waiting for requests for events
 _EventCacheEntry = namedtuple("_EventCacheEntry", ("event", "redacted_event"))
 
 
+class EventRedactBehaviour(Names):
+    """
+    What to do when retrieving a redacted event from the database.
+    """
+
+    AS_IS = NamedConstant()
+    REDACT = NamedConstant()
+    BLOCK = NamedConstant()
+
+
 class EventsWorkerStore(SQLBaseStore):
     def __init__(self, database: Database, db_conn, hs):
         super(EventsWorkerStore, self).__init__(database, db_conn, hs)
@@ -125,25 +137,27 @@ class EventsWorkerStore(SQLBaseStore):
     @defer.inlineCallbacks
     def get_event(
         self,
-        event_id,
-        check_redacted=True,
-        get_prev_content=False,
-        allow_rejected=False,
-        allow_none=False,
-        check_room_id=None,
+        event_id: str,
+        redact_behaviour: EventRedactBehaviour = EventRedactBehaviour.REDACT,
+        get_prev_content: bool = False,
+        allow_rejected: bool = False,
+        allow_none: bool = False,
+        check_room_id: Optional[str] = None,
     ):
         """Get an event from the database by event_id.
 
         Args:
-            event_id (str): The event_id of the event to fetch
-            check_redacted (bool): If True, check if event has been redacted
-                and redact it.
-            get_prev_content (bool): If True and event is a state event,
+            event_id: The event_id of the event to fetch
+            redact_behaviour: Determines what to do with a redacted event. Possible values:
+                * AS_IS - Return the full event body with no redacted content
+                * REDACT - Return the event but with a redacted body
+                * BLOCK - Do not return redacted events
+            get_prev_content: If True and event is a state event,
                 include the previous states content in the unsigned field.
-            allow_rejected (bool): If True return rejected events.
-            allow_none (bool): If True, return None if no event found, if
+            allow_rejected: If True return rejected events.
+            allow_none: If True, return None if no event found, if
                 False throw a NotFoundError
-            check_room_id (str|None): if not None, check the room of the found event.
+            check_room_id: if not None, check the room of the found event.
                 If there is a mismatch, behave as per allow_none.
 
         Returns:
@@ -154,7 +168,7 @@ class EventsWorkerStore(SQLBaseStore):
 
         events = yield self.get_events_as_list(
             [event_id],
-            check_redacted=check_redacted,
+            redact_behaviour=redact_behaviour,
             get_prev_content=get_prev_content,
             allow_rejected=allow_rejected,
         )
@@ -173,27 +187,30 @@ class EventsWorkerStore(SQLBaseStore):
     @defer.inlineCallbacks
     def get_events(
         self,
-        event_ids,
-        check_redacted=True,
-        get_prev_content=False,
-        allow_rejected=False,
+        event_ids: List[str],
+        redact_behaviour: EventRedactBehaviour = EventRedactBehaviour.REDACT,
+        get_prev_content: bool = False,
+        allow_rejected: bool = False,
     ):
         """Get events from the database
 
         Args:
-            event_ids (list): The event_ids of the events to fetch
-            check_redacted (bool): If True, check if event has been redacted
-                and redact it.
-            get_prev_content (bool): If True and event is a state event,
+            event_ids: The event_ids of the events to fetch
+            redact_behaviour: Determines what to do with a redacted event. Possible
+                values:
+                * AS_IS - Return the full event body with no redacted content
+                * REDACT - Return the event but with a redacted body
+                * BLOCK - Do not return redacted events
+            get_prev_content: If True and event is a state event,
                 include the previous states content in the unsigned field.
-            allow_rejected (bool): If True return rejected events.
+            allow_rejected: If True return rejected events.
 
         Returns:
             Deferred : Dict from event_id to event.
         """
         events = yield self.get_events_as_list(
             event_ids,
-            check_redacted=check_redacted,
+            redact_behaviour=redact_behaviour,
             get_prev_content=get_prev_content,
             allow_rejected=allow_rejected,
         )
@@ -203,21 +220,23 @@ class EventsWorkerStore(SQLBaseStore):
     @defer.inlineCallbacks
     def get_events_as_list(
         self,
-        event_ids,
-        check_redacted=True,
-        get_prev_content=False,
-        allow_rejected=False,
+        event_ids: List[str],
+        redact_behaviour: EventRedactBehaviour = EventRedactBehaviour.REDACT,
+        get_prev_content: bool = False,
+        allow_rejected: bool = False,
     ):
         """Get events from the database and return in a list in the same order
         as given by `event_ids` arg.
 
         Args:
-            event_ids (list): The event_ids of the events to fetch
-            check_redacted (bool): If True, check if event has been redacted
-                and redact it.
-            get_prev_content (bool): If True and event is a state event,
+            event_ids: The event_ids of the events to fetch
+            redact_behaviour: Determines what to do with a redacted event. Possible values:
+                * AS_IS - Return the full event body with no redacted content
+                * REDACT - Return the event but with a redacted body
+                * BLOCK - Do not return redacted events
+            get_prev_content: If True and event is a state event,
                 include the previous states content in the unsigned field.
-            allow_rejected (bool): If True return rejected events.
+            allow_rejected: If True, return rejected events.
 
         Returns:
             Deferred[list[EventBase]]: List of events fetched from the database. The
@@ -319,10 +338,14 @@ class EventsWorkerStore(SQLBaseStore):
                     # Update the cache to save doing the checks again.
                     entry.event.internal_metadata.recheck_redaction = False
 
-            if check_redacted and entry.redacted_event:
-                event = entry.redacted_event
-            else:
-                event = entry.event
+            event = entry.event
+
+            if entry.redacted_event:
+                if redact_behaviour == EventRedactBehaviour.BLOCK:
+                    # Skip this event
+                    continue
+                elif redact_behaviour == EventRedactBehaviour.REDACT:
+                    event = entry.redacted_event
 
             events.append(event)
 
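A usage sketch of the new parameter (a hypothetical helper; `store` is assumed
to be an `EventsWorkerStore`): REDACT, the default, substitutes the pruned body
for redacted events; AS_IS returns the original body; BLOCK omits redacted
events from the result entirely.

```python
from twisted.internet import defer

from synapse.storage.data_stores.main.events_worker import EventRedactBehaviour


@defer.inlineCallbacks
def get_unredacted_only(store, event_ids):
    # Any event that has been redacted is simply absent from the result.
    events = yield store.get_events_as_list(
        event_ids, redact_behaviour=EventRedactBehaviour.BLOCK
    )
    return events
```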
diff --git a/synapse/storage/data_stores/main/schema/full_schemas/README.md b/synapse/storage/data_stores/main/schema/full_schemas/README.md
new file mode 100644
index 0000000000..bbd3f18604
--- /dev/null
+++ b/synapse/storage/data_stores/main/schema/full_schemas/README.md
@@ -0,0 +1,13 @@
+# Building full schema dumps
+
+These schemas need to be made from a database that has had all background updates run.
+
+To do so, use `scripts-dev/make_full_schema.sh`. This will produce
+`full.sql.postgres` and `full.sql.sqlite` files.
+
+Ensure Postgres is installed, and that your user has permission to run
+commands such as `createdb`.
+
+```
+./scripts-dev/make_full_schema.sh -p postgres_username -o output_dir/
+```
diff --git a/synapse/storage/data_stores/main/schema/full_schemas/README.txt b/synapse/storage/data_stores/main/schema/full_schemas/README.txt
deleted file mode 100644
index d3f6401344..0000000000
--- a/synapse/storage/data_stores/main/schema/full_schemas/README.txt
+++ /dev/null
@@ -1,19 +0,0 @@
-Building full schema dumps
-==========================
-
-These schemas need to be made from a database that has had all background updates run.
-
-Postgres
---------
-
-$ pg_dump --format=plain --schema-only --no-tablespaces --no-acl --no-owner $DATABASE_NAME| sed -e '/^--/d' -e 's/public\.//g' -e '/^SET /d' -e '/^SELECT /d' > full.sql.postgres
-
-SQLite
-------
-
-$ sqlite3 $DATABASE_FILE ".schema" > full.sql.sqlite
-
-After
------
-
-Delete the CREATE statements for "sqlite_stat1", "schema_version", "applied_schema_deltas", and "applied_module_schemas".
\ No newline at end of file
diff --git a/synapse/storage/data_stores/main/search.py b/synapse/storage/data_stores/main/search.py
index 4eec2fae5e..dfb46ee0f8 100644
--- a/synapse/storage/data_stores/main/search.py
+++ b/synapse/storage/data_stores/main/search.py
@@ -25,6 +25,7 @@ from twisted.internet import defer
 
 from synapse.api.errors import SynapseError
 from synapse.storage._base import SQLBaseStore, make_in_list_sql_clause
+from synapse.storage.data_stores.main.events_worker import EventRedactBehaviour
 from synapse.storage.database import Database
 from synapse.storage.engines import PostgresEngine, Sqlite3Engine
 
@@ -453,7 +454,12 @@ class SearchStore(SearchBackgroundUpdateStore):
 
         results = list(filter(lambda row: row["room_id"] in room_ids, results))
 
-        events = yield self.get_events_as_list([r["event_id"] for r in results])
+        # We set redact_behaviour to BLOCK here to prevent redacted events from
+        # being returned in search results (which would be a data leak).
+        events = yield self.get_events_as_list(
+            [r["event_id"] for r in results],
+            redact_behaviour=EventRedactBehaviour.BLOCK,
+        )
 
         event_map = {ev.event_id: ev for ev in events}
 
diff --git a/synapse/storage/data_stores/main/state.py b/synapse/storage/data_stores/main/state.py
index 9ef7b48c74..dcc6b43cdf 100644
--- a/synapse/storage/data_stores/main/state.py
+++ b/synapse/storage/data_stores/main/state.py
@@ -278,7 +278,7 @@ class StateGroupWorkerStore(
 
     @defer.inlineCallbacks
     def get_room_predecessor(self, room_id):
-        """Get the predecessor room of an upgraded room if one exists.
+        """Get the predecessor of an upgraded room if it exists.
         Otherwise return None.
 
         Args:
@@ -291,14 +291,22 @@ class StateGroupWorkerStore(
                     * room_id (str): The room ID of the predecessor room
                     * event_id (str): The ID of the tombstone event in the predecessor room
 
+                None if the predecessor key is not found, or its value is not a dictionary.
+
         Raises:
-            NotFoundError if the room is unknown
+            NotFoundError if the given room is unknown
         """
         # Retrieve the room's create event
         create_event = yield self.get_create_event_for_room(room_id)
 
-        # Return predecessor if present
-        return create_event.content.get("predecessor", None)
+        # Retrieve the predecessor key of the create event
+        predecessor = create_event.content.get("predecessor", None)
+
+        # Ensure the key is a dictionary
+        if not isinstance(predecessor, dict):
+            return None
+
+        return predecessor
 
     @defer.inlineCallbacks
     def get_create_event_for_room(self, room_id):
@@ -318,7 +326,7 @@ class StateGroupWorkerStore(
 
         # If we can't find the create event, assume we've hit a dead end
         if not create_id:
-            raise NotFoundError("Unknown room %s" % (room_id))
+            raise NotFoundError("Unknown room %s" % (room_id,))
 
         # Retrieve the room's create event and return
         create_event = yield self.get_event(create_id)
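The new `isinstance` guard is there because the create event's content comes
from arbitrary (possibly remote) input. A standalone sketch of the check, with
hypothetical IDs: a well-formed upgraded room carries a predecessor dict with
`room_id` and `event_id` keys (as the docstring above describes), but nothing
stops a server from sending, say, a string instead.

```python
content = {
    "room_version": "5",
    "predecessor": {
        "room_id": "!old:example.com",
        "event_id": "$tombstone:example.com",
    },
}

predecessor = content.get("predecessor", None)

# Only trust the value if it really is a dictionary; anything else
# (a string, a list, None) is treated as "no predecessor".
if not isinstance(predecessor, dict):
    predecessor = None
```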
diff --git a/synapse/util/caches/snapshot_cache.py b/synapse/util/caches/snapshot_cache.py
deleted file mode 100644
index 8318db8d2c..0000000000
--- a/synapse/util/caches/snapshot_cache.py
+++ /dev/null
@@ -1,94 +0,0 @@
-# -*- coding: utf-8 -*-
-# Copyright 2015, 2016 OpenMarket Ltd
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-#     http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-
-from synapse.util.async_helpers import ObservableDeferred
-
-
-class SnapshotCache(object):
-    """Cache for snapshots like the response of /initialSync.
-    The response of initialSync only has to be a recent snapshot of the
-    server state. It shouldn't matter to clients if it is a few minutes out
-    of date.
-
-    This caches a deferred response. Until the deferred completes it will be
-    returned from the cache. This means that if the client retries the request
-    while the response is still being computed, that original response will be
-    used rather than trying to compute a new response.
-
-    Once the deferred completes it will removed from the cache after 5 minutes.
-    We delay removing it from the cache because a client retrying its request
-    could race with us finishing computing the response.
-
-    Rather than tracking precisely how long something has been in the cache we
-    keep two generations of completed responses. Every 5 minutes discard the
-    old generation, move the new generation to the old generation, and set the
-    new generation to be empty. This means that a result will be in the cache
-    somewhere between 5 and 10 minutes.
-    """
-
-    DURATION_MS = 5 * 60 * 1000  # Cache results for 5 minutes.
-
-    def __init__(self):
-        self.pending_result_cache = {}  # Request that haven't finished yet.
-        self.prev_result_cache = {}  # The older requests that have finished.
-        self.next_result_cache = {}  # The newer requests that have finished.
-        self.time_last_rotated_ms = 0
-
-    def rotate(self, time_now_ms):
-        # Rotate once if the cache duration has passed since the last rotation.
-        if time_now_ms - self.time_last_rotated_ms >= self.DURATION_MS:
-            self.prev_result_cache = self.next_result_cache
-            self.next_result_cache = {}
-            self.time_last_rotated_ms += self.DURATION_MS
-
-        # Rotate again if the cache duration has passed twice since the last
-        # rotation.
-        if time_now_ms - self.time_last_rotated_ms >= self.DURATION_MS:
-            self.prev_result_cache = self.next_result_cache
-            self.next_result_cache = {}
-            self.time_last_rotated_ms = time_now_ms
-
-    def get(self, time_now_ms, key):
-        self.rotate(time_now_ms)
-        # This cache is intended to deduplicate requests, so we expect it to be
-        # missed most of the time. So we just lookup the key in all of the
-        # dictionaries rather than trying to short circuit the lookup if the
-        # key is found.
-        result = self.prev_result_cache.get(key)
-        result = self.next_result_cache.get(key, result)
-        result = self.pending_result_cache.get(key, result)
-        if result is not None:
-            return result.observe()
-        else:
-            return None
-
-    def set(self, time_now_ms, key, deferred):
-        self.rotate(time_now_ms)
-
-        result = ObservableDeferred(deferred)
-
-        self.pending_result_cache[key] = result
-
-        def shuffle_along(r):
-            # When the deferred completes we shuffle it along to the first
-            # generation of the result cache. So that it will eventually
-            # expire from the rotation of that cache.
-            self.next_result_cache[key] = result
-            self.pending_result_cache.pop(key, None)
-            return r
-
-        result.addBoth(shuffle_along)
-
-        return result.observe()