summary refs log tree commit diff
diff options
context:
space:
mode:
-rw-r--r--changelog.d/6411.feature1
-rw-r--r--changelog.d/6499.bugfix1
-rw-r--r--changelog.d/6501.misc1
-rw-r--r--changelog.d/6502.removal1
-rw-r--r--changelog.d/6503.misc1
-rw-r--r--changelog.d/6505.misc1
-rw-r--r--changelog.d/6506.misc1
-rw-r--r--changelog.d/6510.misc1
-rw-r--r--changelog.d/6514.bugfix1
-rw-r--r--docs/saml_mapping_providers.md77
-rw-r--r--docs/sample_config.yaml61
-rw-r--r--synapse/app/homeserver.py6
-rw-r--r--synapse/config/saml2_config.py186
-rw-r--r--synapse/event_auth.py8
-rw-r--r--synapse/federation/federation_client.py103
-rw-r--r--synapse/handlers/federation.py101
-rw-r--r--synapse/handlers/initial_sync.py19
-rw-r--r--synapse/handlers/saml_handler.py198
-rw-r--r--synapse/logging/context.py11
-rw-r--r--synapse/storage/data_stores/main/client_ips.py8
-rw-r--r--synapse/storage/data_stores/main/events.py23
-rw-r--r--synapse/storage/data_stores/main/events_bg_updates.py8
-rw-r--r--synapse/storage/data_stores/main/schema/delta/56/redaction_censor.sql1
-rw-r--r--synapse/storage/data_stores/main/schema/delta/56/redaction_censor2.sql4
-rw-r--r--synapse/storage/data_stores/main/schema/delta/56/redaction_censor4.sql16
-rw-r--r--synapse/util/caches/snapshot_cache.py94
-rw-r--r--tests/storage/test_client_ips.py49
-rw-r--r--tests/util/test_logcontext.py24
-rw-r--r--tests/util/test_snapshot_cache.py63
29 files changed, 653 insertions, 416 deletions
diff --git a/changelog.d/6411.feature b/changelog.d/6411.feature
new file mode 100644
index 0000000000..ebea4a208d
--- /dev/null
+++ b/changelog.d/6411.feature
@@ -0,0 +1 @@
+Allow custom SAML username mapping functinality through an external provider plugin.
\ No newline at end of file
diff --git a/changelog.d/6499.bugfix b/changelog.d/6499.bugfix
new file mode 100644
index 0000000000..299feba0f8
--- /dev/null
+++ b/changelog.d/6499.bugfix
@@ -0,0 +1 @@
+Fix support for SQLite 3.7.
diff --git a/changelog.d/6501.misc b/changelog.d/6501.misc
new file mode 100644
index 0000000000..255f45a9c3
--- /dev/null
+++ b/changelog.d/6501.misc
@@ -0,0 +1 @@
+Refactor get_events_from_store_or_dest to return a dict.
diff --git a/changelog.d/6502.removal b/changelog.d/6502.removal
new file mode 100644
index 0000000000..0b72261d58
--- /dev/null
+++ b/changelog.d/6502.removal
@@ -0,0 +1 @@
+Remove redundant code from event authorisation implementation.
diff --git a/changelog.d/6503.misc b/changelog.d/6503.misc
new file mode 100644
index 0000000000..e4e9a5a3d4
--- /dev/null
+++ b/changelog.d/6503.misc
@@ -0,0 +1 @@
+Move get_state methods into FederationHandler.
diff --git a/changelog.d/6505.misc b/changelog.d/6505.misc
new file mode 100644
index 0000000000..3a75b2d9dd
--- /dev/null
+++ b/changelog.d/6505.misc
@@ -0,0 +1 @@
+Make `make_deferred_yieldable` to work with async/await.
diff --git a/changelog.d/6506.misc b/changelog.d/6506.misc
new file mode 100644
index 0000000000..99d7a70bcf
--- /dev/null
+++ b/changelog.d/6506.misc
@@ -0,0 +1 @@
+Remove `SnapshotCache` in favour of `ResponseCache`.
diff --git a/changelog.d/6510.misc b/changelog.d/6510.misc
new file mode 100644
index 0000000000..214f06539b
--- /dev/null
+++ b/changelog.d/6510.misc
@@ -0,0 +1 @@
+Change phone home stats to not assume there is a single database and report information about the database used by the main data store.
diff --git a/changelog.d/6514.bugfix b/changelog.d/6514.bugfix
new file mode 100644
index 0000000000..6dc1985c24
--- /dev/null
+++ b/changelog.d/6514.bugfix
@@ -0,0 +1 @@
+Fix race which occasionally caused deleted devices to reappear.
diff --git a/docs/saml_mapping_providers.md b/docs/saml_mapping_providers.md
new file mode 100644
index 0000000000..92f2380488
--- /dev/null
+++ b/docs/saml_mapping_providers.md
@@ -0,0 +1,77 @@
+# SAML Mapping Providers
+
+A SAML mapping provider is a Python class (loaded via a Python module) that
+works out how to map attributes of a SAML response object to Matrix-specific
+user attributes. Details such as user ID localpart, displayname, and even avatar
+URLs are all things that can be mapped from talking to a SSO service.
+
+As an example, a SSO service may return the email address
+"john.smith@example.com" for a user, whereas Synapse will need to figure out how
+to turn that into a displayname when creating a Matrix user for this individual.
+It may choose `John Smith`, or `Smith, John [Example.com]` or any number of
+variations. As each Synapse configuration may want something different, this is
+where SAML mapping providers come into play.
+
+## Enabling Providers
+
+External mapping providers are provided to Synapse in the form of an external
+Python module. Retrieve this module from [PyPi](https://pypi.org) or elsewhere,
+then tell Synapse where to look for the handler class by editing the
+`saml2_config.user_mapping_provider.module` config option.
+
+`saml2_config.user_mapping_provider.config` allows you to provide custom
+configuration options to the module. Check with the module's documentation for
+what options it provides (if any). The options listed by default are for the
+user mapping provider built in to Synapse. If using a custom module, you should
+comment these options out and use those specified by the module instead.
+
+## Building a Custom Mapping Provider
+
+A custom mapping provider must specify the following methods:
+
+* `__init__(self, parsed_config)`
+   - Arguments:
+     - `parsed_config` - A configuration object that is the return value of the
+       `parse_config` method. You should set any configuration options needed by
+       the module here.
+* `saml_response_to_user_attributes(self, saml_response, failures)`
+    - Arguments:
+      - `saml_response` - A `saml2.response.AuthnResponse` object to extract user
+                          information from.
+      - `failures` - An `int` that represents the amount of times the returned
+                     mxid localpart mapping has failed.  This should be used
+                     to create a deduplicated mxid localpart which should be
+                     returned instead. For example, if this method returns
+                     `john.doe` as the value of `mxid_localpart` in the returned
+                     dict, and that is already taken on the homeserver, this
+                     method will be called again with the same parameters but
+                     with failures=1. The method should then return a different
+                     `mxid_localpart` value, such as `john.doe1`.
+    - This method must return a dictionary, which will then be used by Synapse
+      to build a new user. The following keys are allowed:
+       * `mxid_localpart` - Required. The mxid localpart of the new user.
+       * `displayname` - The displayname of the new user. If not provided, will default to
+                         the value of `mxid_localpart`.
+* `parse_config(config)`
+    - This method should have the `@staticmethod` decoration.
+    - Arguments:
+        - `config` - A `dict` representing the parsed content of the
+          `saml2_config.user_mapping_provider.config` homeserver config option.
+           Runs on homeserver startup. Providers should extract any option values
+           they need here.
+    - Whatever is returned will be passed back to the user mapping provider module's
+      `__init__` method during construction.
+* `get_saml_attributes(config)`
+    - This method should have the `@staticmethod` decoration.
+    - Arguments:
+        - `config` - A object resulting from a call to `parse_config`.
+    - Returns a tuple of two sets. The first set equates to the saml auth
+      response attributes that are required for the module to function, whereas
+      the second set consists of those attributes which can be used if available,
+      but are not necessary.
+
+## Synapse's Default Provider
+
+Synapse has a built-in SAML mapping provider if a custom provider isn't
+specified in the config. It is located at
+[`synapse.handlers.saml_handler.DefaultSamlMappingProvider`](../synapse/handlers/saml_handler.py).
diff --git a/docs/sample_config.yaml b/docs/sample_config.yaml
index 10664ae8f7..4d44e631d1 100644
--- a/docs/sample_config.yaml
+++ b/docs/sample_config.yaml
@@ -1250,33 +1250,58 @@ saml2_config:
   #
   #config_path: "CONFDIR/sp_conf.py"
 
-  # the lifetime of a SAML session. This defines how long a user has to
+  # The lifetime of a SAML session. This defines how long a user has to
   # complete the authentication process, if allow_unsolicited is unset.
   # The default is 5 minutes.
   #
   #saml_session_lifetime: 5m
 
-  # The SAML attribute (after mapping via the attribute maps) to use to derive
-  # the Matrix ID from. 'uid' by default.
+  # An external module can be provided here as a custom solution to
+  # mapping attributes returned from a saml provider onto a matrix user.
   #
-  #mxid_source_attribute: displayName
-
-  # The mapping system to use for mapping the saml attribute onto a matrix ID.
-  # Options include:
-  #  * 'hexencode' (which maps unpermitted characters to '=xx')
-  #  * 'dotreplace' (which replaces unpermitted characters with '.').
-  # The default is 'hexencode'.
-  #
-  #mxid_mapping: dotreplace
+  user_mapping_provider:
+    # The custom module's class. Uncomment to use a custom module.
+    #
+    #module: mapping_provider.SamlMappingProvider
 
-  # In previous versions of synapse, the mapping from SAML attribute to MXID was
-  # always calculated dynamically rather than stored in a table. For backwards-
-  # compatibility, we will look for user_ids matching such a pattern before
-  # creating a new account.
+    # Custom configuration values for the module. Below options are
+    # intended for the built-in provider, they should be changed if
+    # using a custom module. This section will be passed as a Python
+    # dictionary to the module's `parse_config` method.
+    #
+    config:
+      # The SAML attribute (after mapping via the attribute maps) to use
+      # to derive the Matrix ID from. 'uid' by default.
+      #
+      # Note: This used to be configured by the
+      # saml2_config.mxid_source_attribute option. If that is still
+      # defined, its value will be used instead.
+      #
+      #mxid_source_attribute: displayName
+
+      # The mapping system to use for mapping the saml attribute onto a
+      # matrix ID.
+      #
+      # Options include:
+      #  * 'hexencode' (which maps unpermitted characters to '=xx')
+      #  * 'dotreplace' (which replaces unpermitted characters with
+      #     '.').
+      # The default is 'hexencode'.
+      #
+      # Note: This used to be configured by the
+      # saml2_config.mxid_mapping option. If that is still defined, its
+      # value will be used instead.
+      #
+      #mxid_mapping: dotreplace
+
+  # In previous versions of synapse, the mapping from SAML attribute to
+  # MXID was always calculated dynamically rather than stored in a
+  # table. For backwards- compatibility, we will look for user_ids
+  # matching such a pattern before creating a new account.
   #
   # This setting controls the SAML attribute which will be used for this
-  # backwards-compatibility lookup. Typically it should be 'uid', but if the
-  # attribute maps are changed, it may be necessary to change it.
+  # backwards-compatibility lookup. Typically it should be 'uid', but if
+  # the attribute maps are changed, it may be necessary to change it.
   #
   # The default is 'uid'.
   #
diff --git a/synapse/app/homeserver.py b/synapse/app/homeserver.py
index df65d0a989..032010600a 100644
--- a/synapse/app/homeserver.py
+++ b/synapse/app/homeserver.py
@@ -519,8 +519,10 @@ def phone_stats_home(hs, stats, stats_process=_stats_process):
     # Database version
     #
 
-    stats["database_engine"] = hs.database_engine.module.__name__
-    stats["database_server_version"] = hs.database_engine.server_version
+    # This only reports info about the *main* database.
+    stats["database_engine"] = hs.get_datastore().db.engine.module.__name__
+    stats["database_server_version"] = hs.get_datastore().db.engine.server_version
+
     logger.info("Reporting stats to %s: %s" % (hs.config.report_stats_endpoint, stats))
     try:
         yield hs.get_proxied_http_client().put_json(
diff --git a/synapse/config/saml2_config.py b/synapse/config/saml2_config.py
index c5ea2d43a1..b91414aa35 100644
--- a/synapse/config/saml2_config.py
+++ b/synapse/config/saml2_config.py
@@ -14,17 +14,19 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
-import re
+import logging
 
 from synapse.python_dependencies import DependencyException, check_requirements
-from synapse.types import (
-    map_username_to_mxid_localpart,
-    mxid_localpart_allowed_characters,
-)
-from synapse.util.module_loader import load_python_module
+from synapse.util.module_loader import load_module, load_python_module
 
 from ._base import Config, ConfigError
 
+logger = logging.getLogger(__name__)
+
+DEFAULT_USER_MAPPING_PROVIDER = (
+    "synapse.handlers.saml_handler.DefaultSamlMappingProvider"
+)
+
 
 def _dict_merge(merge_dict, into_dict):
     """Do a deep merge of two dicts
@@ -75,15 +77,69 @@ class SAML2Config(Config):
 
         self.saml2_enabled = True
 
-        self.saml2_mxid_source_attribute = saml2_config.get(
-            "mxid_source_attribute", "uid"
-        )
-
         self.saml2_grandfathered_mxid_source_attribute = saml2_config.get(
             "grandfathered_mxid_source_attribute", "uid"
         )
 
-        saml2_config_dict = self._default_saml_config_dict()
+        # user_mapping_provider may be None if the key is present but has no value
+        ump_dict = saml2_config.get("user_mapping_provider") or {}
+
+        # Use the default user mapping provider if not set
+        ump_dict.setdefault("module", DEFAULT_USER_MAPPING_PROVIDER)
+
+        # Ensure a config is present
+        ump_dict["config"] = ump_dict.get("config") or {}
+
+        if ump_dict["module"] == DEFAULT_USER_MAPPING_PROVIDER:
+            # Load deprecated options for use by the default module
+            old_mxid_source_attribute = saml2_config.get("mxid_source_attribute")
+            if old_mxid_source_attribute:
+                logger.warning(
+                    "The config option saml2_config.mxid_source_attribute is deprecated. "
+                    "Please use saml2_config.user_mapping_provider.config"
+                    ".mxid_source_attribute instead."
+                )
+                ump_dict["config"]["mxid_source_attribute"] = old_mxid_source_attribute
+
+            old_mxid_mapping = saml2_config.get("mxid_mapping")
+            if old_mxid_mapping:
+                logger.warning(
+                    "The config option saml2_config.mxid_mapping is deprecated. Please "
+                    "use saml2_config.user_mapping_provider.config.mxid_mapping instead."
+                )
+                ump_dict["config"]["mxid_mapping"] = old_mxid_mapping
+
+        # Retrieve an instance of the module's class
+        # Pass the config dictionary to the module for processing
+        (
+            self.saml2_user_mapping_provider_class,
+            self.saml2_user_mapping_provider_config,
+        ) = load_module(ump_dict)
+
+        # Ensure loaded user mapping module has defined all necessary methods
+        # Note parse_config() is already checked during the call to load_module
+        required_methods = [
+            "get_saml_attributes",
+            "saml_response_to_user_attributes",
+        ]
+        missing_methods = [
+            method
+            for method in required_methods
+            if not hasattr(self.saml2_user_mapping_provider_class, method)
+        ]
+        if missing_methods:
+            raise ConfigError(
+                "Class specified by saml2_config."
+                "user_mapping_provider.module is missing required "
+                "methods: %s" % (", ".join(missing_methods),)
+            )
+
+        # Get the desired saml auth response attributes from the module
+        saml2_config_dict = self._default_saml_config_dict(
+            *self.saml2_user_mapping_provider_class.get_saml_attributes(
+                self.saml2_user_mapping_provider_config
+            )
+        )
         _dict_merge(
             merge_dict=saml2_config.get("sp_config", {}), into_dict=saml2_config_dict
         )
@@ -103,22 +159,27 @@ class SAML2Config(Config):
             saml2_config.get("saml_session_lifetime", "5m")
         )
 
-        mapping = saml2_config.get("mxid_mapping", "hexencode")
-        try:
-            self.saml2_mxid_mapper = MXID_MAPPER_MAP[mapping]
-        except KeyError:
-            raise ConfigError("%s is not a known mxid_mapping" % (mapping,))
-
-    def _default_saml_config_dict(self):
+    def _default_saml_config_dict(
+        self, required_attributes: set, optional_attributes: set
+    ):
+        """Generate a configuration dictionary with required and optional attributes that
+        will be needed to process new user registration
+
+        Args:
+            required_attributes: SAML auth response attributes that are
+                necessary to function
+            optional_attributes: SAML auth response attributes that can be used to add
+                additional information to Synapse user accounts, but are not required
+
+        Returns:
+            dict: A SAML configuration dictionary
+        """
         import saml2
 
         public_baseurl = self.public_baseurl
         if public_baseurl is None:
             raise ConfigError("saml2_config requires a public_baseurl to be set")
 
-        required_attributes = {"uid", self.saml2_mxid_source_attribute}
-
-        optional_attributes = {"displayName"}
         if self.saml2_grandfathered_mxid_source_attribute:
             optional_attributes.add(self.saml2_grandfathered_mxid_source_attribute)
         optional_attributes -= required_attributes
@@ -207,33 +268,58 @@ class SAML2Config(Config):
           #
           #config_path: "%(config_dir_path)s/sp_conf.py"
 
-          # the lifetime of a SAML session. This defines how long a user has to
+          # The lifetime of a SAML session. This defines how long a user has to
           # complete the authentication process, if allow_unsolicited is unset.
           # The default is 5 minutes.
           #
           #saml_session_lifetime: 5m
 
-          # The SAML attribute (after mapping via the attribute maps) to use to derive
-          # the Matrix ID from. 'uid' by default.
+          # An external module can be provided here as a custom solution to
+          # mapping attributes returned from a saml provider onto a matrix user.
           #
-          #mxid_source_attribute: displayName
-
-          # The mapping system to use for mapping the saml attribute onto a matrix ID.
-          # Options include:
-          #  * 'hexencode' (which maps unpermitted characters to '=xx')
-          #  * 'dotreplace' (which replaces unpermitted characters with '.').
-          # The default is 'hexencode'.
-          #
-          #mxid_mapping: dotreplace
-
-          # In previous versions of synapse, the mapping from SAML attribute to MXID was
-          # always calculated dynamically rather than stored in a table. For backwards-
-          # compatibility, we will look for user_ids matching such a pattern before
-          # creating a new account.
+          user_mapping_provider:
+            # The custom module's class. Uncomment to use a custom module.
+            #
+            #module: mapping_provider.SamlMappingProvider
+
+            # Custom configuration values for the module. Below options are
+            # intended for the built-in provider, they should be changed if
+            # using a custom module. This section will be passed as a Python
+            # dictionary to the module's `parse_config` method.
+            #
+            config:
+              # The SAML attribute (after mapping via the attribute maps) to use
+              # to derive the Matrix ID from. 'uid' by default.
+              #
+              # Note: This used to be configured by the
+              # saml2_config.mxid_source_attribute option. If that is still
+              # defined, its value will be used instead.
+              #
+              #mxid_source_attribute: displayName
+
+              # The mapping system to use for mapping the saml attribute onto a
+              # matrix ID.
+              #
+              # Options include:
+              #  * 'hexencode' (which maps unpermitted characters to '=xx')
+              #  * 'dotreplace' (which replaces unpermitted characters with
+              #     '.').
+              # The default is 'hexencode'.
+              #
+              # Note: This used to be configured by the
+              # saml2_config.mxid_mapping option. If that is still defined, its
+              # value will be used instead.
+              #
+              #mxid_mapping: dotreplace
+
+          # In previous versions of synapse, the mapping from SAML attribute to
+          # MXID was always calculated dynamically rather than stored in a
+          # table. For backwards- compatibility, we will look for user_ids
+          # matching such a pattern before creating a new account.
           #
           # This setting controls the SAML attribute which will be used for this
-          # backwards-compatibility lookup. Typically it should be 'uid', but if the
-          # attribute maps are changed, it may be necessary to change it.
+          # backwards-compatibility lookup. Typically it should be 'uid', but if
+          # the attribute maps are changed, it may be necessary to change it.
           #
           # The default is 'uid'.
           #
@@ -241,23 +327,3 @@ class SAML2Config(Config):
         """ % {
             "config_dir_path": config_dir_path
         }
-
-
-DOT_REPLACE_PATTERN = re.compile(
-    ("[^%s]" % (re.escape("".join(mxid_localpart_allowed_characters)),))
-)
-
-
-def dot_replace_for_mxid(username: str) -> str:
-    username = username.lower()
-    username = DOT_REPLACE_PATTERN.sub(".", username)
-
-    # regular mxids aren't allowed to start with an underscore either
-    username = re.sub("^_", "", username)
-    return username
-
-
-MXID_MAPPER_MAP = {
-    "hexencode": map_username_to_mxid_localpart,
-    "dotreplace": dot_replace_for_mxid,
-}
diff --git a/synapse/event_auth.py b/synapse/event_auth.py
index ec3243b27b..c940b84470 100644
--- a/synapse/event_auth.py
+++ b/synapse/event_auth.py
@@ -42,6 +42,8 @@ def check(room_version, event, auth_events, do_sig_check=True, do_size_check=Tru
     Returns:
          if the auth checks pass.
     """
+    assert isinstance(auth_events, dict)
+
     if do_size_check:
         _check_size_limits(event)
 
@@ -74,12 +76,6 @@ def check(room_version, event, auth_events, do_sig_check=True, do_size_check=Tru
             if not event.signatures.get(event_id_domain):
                 raise AuthError(403, "Event not signed by sending server")
 
-    if auth_events is None:
-        # Oh, we don't know what the state of the room was, so we
-        # are trusting that this is allowed (at least for now)
-        logger.warning("Trusting event: %s", event.event_id)
-        return
-
     if event.type == EventTypes.Create:
         sender_domain = get_domain_from_id(event.sender)
         room_id_domain = get_domain_from_id(event.room_id)
diff --git a/synapse/federation/federation_client.py b/synapse/federation/federation_client.py
index 709449c9e3..d396e6564f 100644
--- a/synapse/federation/federation_client.py
+++ b/synapse/federation/federation_client.py
@@ -18,8 +18,6 @@ import copy
 import itertools
 import logging
 
-from six.moves import range
-
 from prometheus_client import Counter
 
 from twisted.internet import defer
@@ -39,7 +37,7 @@ from synapse.api.room_versions import (
 )
 from synapse.events import builder, room_version_to_event_format
 from synapse.federation.federation_base import FederationBase, event_from_pdu_json
-from synapse.logging.context import make_deferred_yieldable, run_in_background
+from synapse.logging.context import make_deferred_yieldable
 from synapse.logging.utils import log_function
 from synapse.util import unwrapFirstError
 from synapse.util.caches.expiringcache import ExpiringCache
@@ -310,19 +308,12 @@ class FederationClient(FederationBase):
         return signed_pdu
 
     @defer.inlineCallbacks
-    @log_function
-    def get_state_for_room(self, destination, room_id, event_id):
-        """Requests all of the room state at a given event from a remote homeserver.
-
-        Args:
-            destination (str): The remote homeserver to query for the state.
-            room_id (str): The id of the room we're interested in.
-            event_id (str): The id of the event we want the state at.
+    def get_room_state_ids(self, destination: str, room_id: str, event_id: str):
+        """Calls the /state_ids endpoint to fetch the state at a particular point
+        in the room, and the auth events for the given event
 
         Returns:
-            Deferred[Tuple[List[EventBase], List[EventBase]]]:
-                A list of events in the state, and a list of events in the auth chain
-                for the given event.
+            Tuple[List[str], List[str]]:  a tuple of (state event_ids, auth event_ids)
         """
         result = yield self.transport_layer.get_room_state_ids(
             destination, room_id, event_id=event_id
@@ -331,86 +322,12 @@ class FederationClient(FederationBase):
         state_event_ids = result["pdu_ids"]
         auth_event_ids = result.get("auth_chain_ids", [])
 
-        fetched_events, failed_to_fetch = yield self.get_events_from_store_or_dest(
-            destination, room_id, set(state_event_ids + auth_event_ids)
-        )
-
-        if failed_to_fetch:
-            logger.warning(
-                "Failed to fetch missing state/auth events for %s: %s",
-                room_id,
-                failed_to_fetch,
-            )
-
-        event_map = {ev.event_id: ev for ev in fetched_events}
-
-        pdus = [event_map[e_id] for e_id in state_event_ids if e_id in event_map]
-        auth_chain = [event_map[e_id] for e_id in auth_event_ids if e_id in event_map]
-
-        auth_chain.sort(key=lambda e: e.depth)
-
-        return pdus, auth_chain
-
-    @defer.inlineCallbacks
-    def get_events_from_store_or_dest(self, destination, room_id, event_ids):
-        """Fetch events from a remote destination, checking if we already have them.
-
-        Args:
-            destination (str)
-            room_id (str)
-            event_ids (list)
-
-        Returns:
-            Deferred: A deferred resolving to a 2-tuple where the first is a list of
-            events and the second is a list of event ids that we failed to fetch.
-        """
-        seen_events = yield self.store.get_events(event_ids, allow_rejected=True)
-        signed_events = list(seen_events.values())
-
-        failed_to_fetch = set()
-
-        missing_events = set(event_ids)
-        for k in seen_events:
-            missing_events.discard(k)
-
-        if not missing_events:
-            return signed_events, failed_to_fetch
-
-        logger.debug(
-            "Fetching unknown state/auth events %s for room %s",
-            missing_events,
-            event_ids,
-        )
-
-        room_version = yield self.store.get_room_version(room_id)
-
-        batch_size = 20
-        missing_events = list(missing_events)
-        for i in range(0, len(missing_events), batch_size):
-            batch = set(missing_events[i : i + batch_size])
-
-            deferreds = [
-                run_in_background(
-                    self.get_pdu,
-                    destinations=[destination],
-                    event_id=e_id,
-                    room_version=room_version,
-                )
-                for e_id in batch
-            ]
-
-            res = yield make_deferred_yieldable(
-                defer.DeferredList(deferreds, consumeErrors=True)
-            )
-            for success, result in res:
-                if success and result:
-                    signed_events.append(result)
-                    batch.discard(result.event_id)
-
-            # We removed all events we successfully fetched from `batch`
-            failed_to_fetch.update(batch)
+        if not isinstance(state_event_ids, list) or not isinstance(
+            auth_event_ids, list
+        ):
+            raise Exception("invalid response from /state_ids")
 
-        return signed_events, failed_to_fetch
+        return state_event_ids, auth_event_ids
 
     @defer.inlineCallbacks
     @log_function
diff --git a/synapse/handlers/federation.py b/synapse/handlers/federation.py
index bc26921768..c0dcf9abf8 100644
--- a/synapse/handlers/federation.py
+++ b/synapse/handlers/federation.py
@@ -64,7 +64,7 @@ from synapse.replication.http.federation import (
 from synapse.replication.http.membership import ReplicationUserJoinedLeftRoomRestServlet
 from synapse.state import StateResolutionStore, resolve_events_with_store
 from synapse.types import UserID, get_domain_from_id
-from synapse.util import unwrapFirstError
+from synapse.util import batch_iter, unwrapFirstError
 from synapse.util.async_helpers import Linearizer
 from synapse.util.distributor import user_joined_room
 from synapse.util.retryutils import NotRetryingDestination
@@ -379,11 +379,9 @@ class FederationHandler(BaseHandler):
                             (
                                 remote_state,
                                 got_auth_chain,
-                            ) = yield self.federation_client.get_state_for_room(
-                                origin, room_id, p
-                            )
+                            ) = yield self._get_state_for_room(origin, room_id, p)
 
-                            # we want the state *after* p; get_state_for_room returns the
+                            # we want the state *after* p; _get_state_for_room returns the
                             # state *before* p.
                             remote_event = yield self.federation_client.get_pdu(
                                 [origin], p, room_version, outlier=True
@@ -584,6 +582,97 @@ class FederationHandler(BaseHandler):
                         raise
 
     @defer.inlineCallbacks
+    @log_function
+    def _get_state_for_room(self, destination, room_id, event_id):
+        """Requests all of the room state at a given event from a remote homeserver.
+
+        Args:
+            destination (str): The remote homeserver to query for the state.
+            room_id (str): The id of the room we're interested in.
+            event_id (str): The id of the event we want the state at.
+
+        Returns:
+            Deferred[Tuple[List[EventBase], List[EventBase]]]:
+                A list of events in the state, and a list of events in the auth chain
+                for the given event.
+        """
+        (
+            state_event_ids,
+            auth_event_ids,
+        ) = yield self.federation_client.get_room_state_ids(
+            destination, room_id, event_id=event_id
+        )
+
+        desired_events = set(state_event_ids + auth_event_ids)
+        event_map = yield self._get_events_from_store_or_dest(
+            destination, room_id, desired_events
+        )
+
+        failed_to_fetch = desired_events - event_map.keys()
+        if failed_to_fetch:
+            logger.warning(
+                "Failed to fetch missing state/auth events for %s: %s",
+                room_id,
+                failed_to_fetch,
+            )
+
+        pdus = [event_map[e_id] for e_id in state_event_ids if e_id in event_map]
+        auth_chain = [event_map[e_id] for e_id in auth_event_ids if e_id in event_map]
+
+        auth_chain.sort(key=lambda e: e.depth)
+
+        return pdus, auth_chain
+
+    @defer.inlineCallbacks
+    def _get_events_from_store_or_dest(self, destination, room_id, event_ids):
+        """Fetch events from a remote destination, checking if we already have them.
+
+        Args:
+            destination (str)
+            room_id (str)
+            event_ids (Iterable[str])
+
+        Returns:
+            Deferred[dict[str, EventBase]]: A deferred resolving to a map
+            from event_id to event
+        """
+        fetched_events = yield self.store.get_events(event_ids, allow_rejected=True)
+
+        missing_events = set(event_ids) - fetched_events.keys()
+
+        if not missing_events:
+            return fetched_events
+
+        logger.debug(
+            "Fetching unknown state/auth events %s for room %s",
+            missing_events,
+            event_ids,
+        )
+
+        room_version = yield self.store.get_room_version(room_id)
+
+        # XXX 20 requests at once? really?
+        for batch in batch_iter(missing_events, 20):
+            deferreds = [
+                run_in_background(
+                    self.federation_client.get_pdu,
+                    destinations=[destination],
+                    event_id=e_id,
+                    room_version=room_version,
+                )
+                for e_id in batch
+            ]
+
+            res = yield make_deferred_yieldable(
+                defer.DeferredList(deferreds, consumeErrors=True)
+            )
+            for success, result in res:
+                if success and result:
+                    fetched_events[result.event_id] = result
+
+        return fetched_events
+
+    @defer.inlineCallbacks
     def _process_received_pdu(self, origin, event, state, auth_chain):
         """ Called when we have a new pdu. We need to do auth checks and put it
         through the StateHandler.
@@ -723,7 +812,7 @@ class FederationHandler(BaseHandler):
         state_events = {}
         events_to_state = {}
         for e_id in edges:
-            state, auth = yield self.federation_client.get_state_for_room(
+            state, auth = yield self._get_state_for_room(
                 destination=dest, room_id=room_id, event_id=e_id
             )
             auth_events.update({a.event_id: a for a in auth})
diff --git a/synapse/handlers/initial_sync.py b/synapse/handlers/initial_sync.py
index 81dce96f4b..73c110a92b 100644
--- a/synapse/handlers/initial_sync.py
+++ b/synapse/handlers/initial_sync.py
@@ -26,7 +26,7 @@ from synapse.streams.config import PaginationConfig
 from synapse.types import StreamToken, UserID
 from synapse.util import unwrapFirstError
 from synapse.util.async_helpers import concurrently_execute
-from synapse.util.caches.snapshot_cache import SnapshotCache
+from synapse.util.caches.response_cache import ResponseCache
 from synapse.visibility import filter_events_for_client
 
 from ._base import BaseHandler
@@ -41,7 +41,7 @@ class InitialSyncHandler(BaseHandler):
         self.state = hs.get_state_handler()
         self.clock = hs.get_clock()
         self.validator = EventValidator()
-        self.snapshot_cache = SnapshotCache()
+        self.snapshot_cache = ResponseCache(hs, "initial_sync_cache")
         self._event_serializer = hs.get_event_client_serializer()
         self.storage = hs.get_storage()
         self.state_store = self.storage.state
@@ -79,17 +79,14 @@ class InitialSyncHandler(BaseHandler):
             as_client_event,
             include_archived,
         )
-        now_ms = self.clock.time_msec()
-        result = self.snapshot_cache.get(now_ms, key)
-        if result is not None:
-            return result
 
-        return self.snapshot_cache.set(
-            now_ms,
+        return self.snapshot_cache.wrap(
             key,
-            self._snapshot_all_rooms(
-                user_id, pagin_config, as_client_event, include_archived
-            ),
+            self._snapshot_all_rooms,
+            user_id,
+            pagin_config,
+            as_client_event,
+            include_archived,
         )
 
     @defer.inlineCallbacks
diff --git a/synapse/handlers/saml_handler.py b/synapse/handlers/saml_handler.py
index cc9e6b9bd0..0082f85c26 100644
--- a/synapse/handlers/saml_handler.py
+++ b/synapse/handlers/saml_handler.py
@@ -13,20 +13,36 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 import logging
+import re
+from typing import Tuple
 
 import attr
 import saml2
+import saml2.response
 from saml2.client import Saml2Client
 
 from synapse.api.errors import SynapseError
+from synapse.config import ConfigError
 from synapse.http.servlet import parse_string
 from synapse.rest.client.v1.login import SSOAuthHandler
-from synapse.types import UserID, map_username_to_mxid_localpart
+from synapse.types import (
+    UserID,
+    map_username_to_mxid_localpart,
+    mxid_localpart_allowed_characters,
+)
 from synapse.util.async_helpers import Linearizer
 
 logger = logging.getLogger(__name__)
 
 
+@attr.s
+class Saml2SessionData:
+    """Data we track about SAML2 sessions"""
+
+    # time the session was created, in milliseconds
+    creation_time = attr.ib()
+
+
 class SamlHandler:
     def __init__(self, hs):
         self._saml_client = Saml2Client(hs.config.saml2_sp_config)
@@ -37,11 +53,14 @@ class SamlHandler:
         self._datastore = hs.get_datastore()
         self._hostname = hs.hostname
         self._saml2_session_lifetime = hs.config.saml2_session_lifetime
-        self._mxid_source_attribute = hs.config.saml2_mxid_source_attribute
         self._grandfathered_mxid_source_attribute = (
             hs.config.saml2_grandfathered_mxid_source_attribute
         )
-        self._mxid_mapper = hs.config.saml2_mxid_mapper
+
+        # plugin to do custom mapping from saml response to mxid
+        self._user_mapping_provider = hs.config.saml2_user_mapping_provider_class(
+            hs.config.saml2_user_mapping_provider_config
+        )
 
         # identifier for the external_ids table
         self._auth_provider_id = "saml"
@@ -118,22 +137,10 @@ class SamlHandler:
             remote_user_id = saml2_auth.ava["uid"][0]
         except KeyError:
             logger.warning("SAML2 response lacks a 'uid' attestation")
-            raise SynapseError(400, "uid not in SAML2 response")
-
-        try:
-            mxid_source = saml2_auth.ava[self._mxid_source_attribute][0]
-        except KeyError:
-            logger.warning(
-                "SAML2 response lacks a '%s' attestation", self._mxid_source_attribute
-            )
-            raise SynapseError(
-                400, "%s not in SAML2 response" % (self._mxid_source_attribute,)
-            )
+            raise SynapseError(400, "'uid' not in SAML2 response")
 
         self._outstanding_requests_dict.pop(saml2_auth.in_response_to, None)
 
-        displayName = saml2_auth.ava.get("displayName", [None])[0]
-
         with (await self._mapping_lock.queue(self._auth_provider_id)):
             # first of all, check if we already have a mapping for this user
             logger.info(
@@ -173,22 +180,46 @@ class SamlHandler:
                     )
                     return registered_user_id
 
-            # figure out a new mxid for this user
-            base_mxid_localpart = self._mxid_mapper(mxid_source)
+            # Map saml response to user attributes using the configured mapping provider
+            for i in range(1000):
+                attribute_dict = self._user_mapping_provider.saml_response_to_user_attributes(
+                    saml2_auth, i
+                )
+
+                logger.debug(
+                    "Retrieved SAML attributes from user mapping provider: %s "
+                    "(attempt %d)",
+                    attribute_dict,
+                    i,
+                )
+
+                localpart = attribute_dict.get("mxid_localpart")
+                if not localpart:
+                    logger.error(
+                        "SAML mapping provider plugin did not return a "
+                        "mxid_localpart object"
+                    )
+                    raise SynapseError(500, "Error parsing SAML2 response")
 
-            suffix = 0
-            while True:
-                localpart = base_mxid_localpart + (str(suffix) if suffix else "")
+                displayname = attribute_dict.get("displayname")
+
+                # Check if this mxid already exists
                 if not await self._datastore.get_users_by_id_case_insensitive(
                     UserID(localpart, self._hostname).to_string()
                 ):
+                    # This mxid is free
                     break
-                suffix += 1
-            logger.info("Allocating mxid for new user with localpart %s", localpart)
+            else:
+                # Unable to generate a username in 1000 iterations
+                # Break and return error to the user
+                raise SynapseError(
+                    500, "Unable to generate a Matrix ID from the SAML response"
+                )
 
             registered_user_id = await self._registration_handler.register_user(
-                localpart=localpart, default_display_name=displayName
+                localpart=localpart, default_display_name=displayname
             )
+
             await self._datastore.record_user_external_id(
                 self._auth_provider_id, remote_user_id, registered_user_id
             )
@@ -205,9 +236,120 @@ class SamlHandler:
             del self._outstanding_requests_dict[reqid]
 
 
+DOT_REPLACE_PATTERN = re.compile(
+    ("[^%s]" % (re.escape("".join(mxid_localpart_allowed_characters)),))
+)
+
+
+def dot_replace_for_mxid(username: str) -> str:
+    username = username.lower()
+    username = DOT_REPLACE_PATTERN.sub(".", username)
+
+    # regular mxids aren't allowed to start with an underscore either
+    username = re.sub("^_", "", username)
+    return username
+
+
+MXID_MAPPER_MAP = {
+    "hexencode": map_username_to_mxid_localpart,
+    "dotreplace": dot_replace_for_mxid,
+}
+
+
 @attr.s
-class Saml2SessionData:
-    """Data we track about SAML2 sessions"""
+class SamlConfig(object):
+    mxid_source_attribute = attr.ib()
+    mxid_mapper = attr.ib()
 
-    # time the session was created, in milliseconds
-    creation_time = attr.ib()
+
+class DefaultSamlMappingProvider(object):
+    __version__ = "0.0.1"
+
+    def __init__(self, parsed_config: SamlConfig):
+        """The default SAML user mapping provider
+
+        Args:
+            parsed_config: Module configuration
+        """
+        self._mxid_source_attribute = parsed_config.mxid_source_attribute
+        self._mxid_mapper = parsed_config.mxid_mapper
+
+    def saml_response_to_user_attributes(
+        self, saml_response: saml2.response.AuthnResponse, failures: int = 0,
+    ) -> dict:
+        """Maps some text from a SAML response to attributes of a new user
+
+        Args:
+            saml_response: A SAML auth response object
+
+            failures: How many times a call to this function with this
+                saml_response has resulted in a failure
+
+        Returns:
+            dict: A dict containing new user attributes. Possible keys:
+                * mxid_localpart (str): Required. The localpart of the user's mxid
+                * displayname (str): The displayname of the user
+        """
+        try:
+            mxid_source = saml_response.ava[self._mxid_source_attribute][0]
+        except KeyError:
+            logger.warning(
+                "SAML2 response lacks a '%s' attestation", self._mxid_source_attribute,
+            )
+            raise SynapseError(
+                400, "%s not in SAML2 response" % (self._mxid_source_attribute,)
+            )
+
+        # Use the configured mapper for this mxid_source
+        base_mxid_localpart = self._mxid_mapper(mxid_source)
+
+        # Append suffix integer if last call to this function failed to produce
+        # a usable mxid
+        localpart = base_mxid_localpart + (str(failures) if failures else "")
+
+        # Retrieve the display name from the saml response
+        # If displayname is None, the mxid_localpart will be used instead
+        displayname = saml_response.ava.get("displayName", [None])[0]
+
+        return {
+            "mxid_localpart": localpart,
+            "displayname": displayname,
+        }
+
+    @staticmethod
+    def parse_config(config: dict) -> SamlConfig:
+        """Parse the dict provided by the homeserver's config
+        Args:
+            config: A dictionary containing configuration options for this provider
+        Returns:
+            SamlConfig: A custom config object for this module
+        """
+        # Parse config options and use defaults where necessary
+        mxid_source_attribute = config.get("mxid_source_attribute", "uid")
+        mapping_type = config.get("mxid_mapping", "hexencode")
+
+        # Retrieve the associating mapping function
+        try:
+            mxid_mapper = MXID_MAPPER_MAP[mapping_type]
+        except KeyError:
+            raise ConfigError(
+                "saml2_config.user_mapping_provider.config: '%s' is not a valid "
+                "mxid_mapping value" % (mapping_type,)
+            )
+
+        return SamlConfig(mxid_source_attribute, mxid_mapper)
+
+    @staticmethod
+    def get_saml_attributes(config: SamlConfig) -> Tuple[set, set]:
+        """Returns the required attributes of a SAML
+
+        Args:
+            config: A SamlConfig object containing configuration params for this provider
+
+        Returns:
+            tuple[set,set]: The first set equates to the saml auth response
+                attributes that are required for the module to function, whereas the
+                second set consists of those attributes which can be used if
+                available, but are not necessary
+        """
+        return {"uid", config.mxid_source_attribute}, {"displayName"}
diff --git a/synapse/logging/context.py b/synapse/logging/context.py
index 2c1fb9ddac..6747f29e6a 100644
--- a/synapse/logging/context.py
+++ b/synapse/logging/context.py
@@ -23,6 +23,7 @@ them.
 See doc/log_contexts.rst for details on how this works.
 """
 
+import inspect
 import logging
 import threading
 import types
@@ -612,7 +613,8 @@ def run_in_background(f, *args, **kwargs):
 
 
 def make_deferred_yieldable(deferred):
-    """Given a deferred, make it follow the Synapse logcontext rules:
+    """Given a deferred (or coroutine), make it follow the Synapse logcontext
+    rules:
 
     If the deferred has completed (or is not actually a Deferred), essentially
     does nothing (just returns another completed deferred with the
@@ -624,6 +626,13 @@ def make_deferred_yieldable(deferred):
 
     (This is more-or-less the opposite operation to run_in_background.)
     """
+    if inspect.isawaitable(deferred):
+        # If we're given a coroutine we convert it to a deferred so that we
+        # run it and find out if it immediately finishes, it it does then we
+        # don't need to fiddle with log contexts at all and can return
+        # immediately.
+        deferred = defer.ensureDeferred(deferred)
+
     if not isinstance(deferred, defer.Deferred):
         return deferred
 
diff --git a/synapse/storage/data_stores/main/client_ips.py b/synapse/storage/data_stores/main/client_ips.py
index 320c5b0f07..add3037b69 100644
--- a/synapse/storage/data_stores/main/client_ips.py
+++ b/synapse/storage/data_stores/main/client_ips.py
@@ -451,16 +451,18 @@ class ClientIpStore(ClientIpBackgroundUpdateStore):
                 # Technically an access token might not be associated with
                 # a device so we need to check.
                 if device_id:
-                    self.db.simple_upsert_txn(
+                    # this is always an update rather than an upsert: the row should
+                    # already exist, and if it doesn't, that may be because it has been
+                    # deleted, and we don't want to re-create it.
+                    self.db.simple_update_txn(
                         txn,
                         table="devices",
                         keyvalues={"user_id": user_id, "device_id": device_id},
-                        values={
+                        updatevalues={
                             "user_agent": user_agent,
                             "last_seen": last_seen,
                             "ip": ip,
                         },
-                        lock=False,
                     )
             except Exception as e:
                 # Failed to upsert, log and continue
diff --git a/synapse/storage/data_stores/main/events.py b/synapse/storage/data_stores/main/events.py
index da1529f6ea..998bba1aad 100644
--- a/synapse/storage/data_stores/main/events.py
+++ b/synapse/storage/data_stores/main/events.py
@@ -1039,20 +1039,25 @@ class EventsStore(
             },
         )
 
-    @defer.inlineCallbacks
-    def _censor_redactions(self):
+    async def _censor_redactions(self):
         """Censors all redactions older than the configured period that haven't
         been censored yet.
 
         By censor we mean update the event_json table with the redacted event.
-
-        Returns:
-            Deferred
         """
 
         if self.hs.config.redaction_retention_period is None:
             return
 
+        if not (
+            await self.db.updates.has_completed_background_update(
+                "redactions_have_censored_ts_idx"
+            )
+        ):
+            # We don't want to run this until the appropriate index has been
+            # created.
+            return
+
         before_ts = self._clock.time_msec() - self.hs.config.redaction_retention_period
 
         # We fetch all redactions that:
@@ -1074,15 +1079,15 @@ class EventsStore(
             LIMIT ?
         """
 
-        rows = yield self.db.execute(
+        rows = await self.db.execute(
             "_censor_redactions_fetch", None, sql, before_ts, 100
         )
 
         updates = []
 
         for redaction_id, event_id in rows:
-            redaction_event = yield self.get_event(redaction_id, allow_none=True)
-            original_event = yield self.get_event(
+            redaction_event = await self.get_event(redaction_id, allow_none=True)
+            original_event = await self.get_event(
                 event_id, allow_rejected=True, allow_none=True
             )
 
@@ -1115,7 +1120,7 @@ class EventsStore(
                     updatevalues={"have_censored": True},
                 )
 
-        yield self.db.runInteraction("_update_censor_txn", _update_censor_txn)
+        await self.db.runInteraction("_update_censor_txn", _update_censor_txn)
 
     def _censor_event_txn(self, txn, event_id, pruned_json):
         """Censor an event by replacing its JSON in the event_json table with the
diff --git a/synapse/storage/data_stores/main/events_bg_updates.py b/synapse/storage/data_stores/main/events_bg_updates.py
index efee17b929..5177b71016 100644
--- a/synapse/storage/data_stores/main/events_bg_updates.py
+++ b/synapse/storage/data_stores/main/events_bg_updates.py
@@ -90,6 +90,14 @@ class EventsBackgroundUpdatesStore(SQLBaseStore):
             "event_store_labels", self._event_store_labels
         )
 
+        self.db.updates.register_background_index_update(
+            "redactions_have_censored_ts_idx",
+            index_name="redactions_have_censored_ts",
+            table="redactions",
+            columns=["received_ts"],
+            where_clause="NOT have_censored",
+        )
+
     @defer.inlineCallbacks
     def _background_reindex_fields_sender(self, progress, batch_size):
         target_min_stream_id = progress["target_min_stream_id_inclusive"]
diff --git a/synapse/storage/data_stores/main/schema/delta/56/redaction_censor.sql b/synapse/storage/data_stores/main/schema/delta/56/redaction_censor.sql
index fe51b02309..ea95db0ed7 100644
--- a/synapse/storage/data_stores/main/schema/delta/56/redaction_censor.sql
+++ b/synapse/storage/data_stores/main/schema/delta/56/redaction_censor.sql
@@ -14,4 +14,3 @@
  */
 
 ALTER TABLE redactions ADD COLUMN have_censored BOOL NOT NULL DEFAULT false;
-CREATE INDEX redactions_have_censored ON redactions(event_id) WHERE not have_censored;
diff --git a/synapse/storage/data_stores/main/schema/delta/56/redaction_censor2.sql b/synapse/storage/data_stores/main/schema/delta/56/redaction_censor2.sql
index 77a5eca499..49ce35d794 100644
--- a/synapse/storage/data_stores/main/schema/delta/56/redaction_censor2.sql
+++ b/synapse/storage/data_stores/main/schema/delta/56/redaction_censor2.sql
@@ -14,7 +14,9 @@
  */
 
 ALTER TABLE redactions ADD COLUMN received_ts BIGINT;
-CREATE INDEX redactions_have_censored_ts ON redactions(received_ts) WHERE not have_censored;
 
 INSERT INTO background_updates (update_name, progress_json) VALUES
   ('redactions_received_ts', '{}');
+
+INSERT INTO background_updates (update_name, progress_json) VALUES
+  ('redactions_have_censored_ts_idx', '{}');
diff --git a/synapse/storage/data_stores/main/schema/delta/56/redaction_censor4.sql b/synapse/storage/data_stores/main/schema/delta/56/redaction_censor4.sql
new file mode 100644
index 0000000000..b7550f6f4e
--- /dev/null
+++ b/synapse/storage/data_stores/main/schema/delta/56/redaction_censor4.sql
@@ -0,0 +1,16 @@
+/* Copyright 2019 The Matrix.org Foundation C.I.C.
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+DROP INDEX IF EXISTS redactions_have_censored;
diff --git a/synapse/util/caches/snapshot_cache.py b/synapse/util/caches/snapshot_cache.py
deleted file mode 100644
index 8318db8d2c..0000000000
--- a/synapse/util/caches/snapshot_cache.py
+++ /dev/null
@@ -1,94 +0,0 @@
-# -*- coding: utf-8 -*-
-# Copyright 2015, 2016 OpenMarket Ltd
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-#     http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-
-from synapse.util.async_helpers import ObservableDeferred
-
-
-class SnapshotCache(object):
-    """Cache for snapshots like the response of /initialSync.
-    The response of initialSync only has to be a recent snapshot of the
-    server state. It shouldn't matter to clients if it is a few minutes out
-    of date.
-
-    This caches a deferred response. Until the deferred completes it will be
-    returned from the cache. This means that if the client retries the request
-    while the response is still being computed, that original response will be
-    used rather than trying to compute a new response.
-
-    Once the deferred completes it will removed from the cache after 5 minutes.
-    We delay removing it from the cache because a client retrying its request
-    could race with us finishing computing the response.
-
-    Rather than tracking precisely how long something has been in the cache we
-    keep two generations of completed responses. Every 5 minutes discard the
-    old generation, move the new generation to the old generation, and set the
-    new generation to be empty. This means that a result will be in the cache
-    somewhere between 5 and 10 minutes.
-    """
-
-    DURATION_MS = 5 * 60 * 1000  # Cache results for 5 minutes.
-
-    def __init__(self):
-        self.pending_result_cache = {}  # Request that haven't finished yet.
-        self.prev_result_cache = {}  # The older requests that have finished.
-        self.next_result_cache = {}  # The newer requests that have finished.
-        self.time_last_rotated_ms = 0
-
-    def rotate(self, time_now_ms):
-        # Rotate once if the cache duration has passed since the last rotation.
-        if time_now_ms - self.time_last_rotated_ms >= self.DURATION_MS:
-            self.prev_result_cache = self.next_result_cache
-            self.next_result_cache = {}
-            self.time_last_rotated_ms += self.DURATION_MS
-
-        # Rotate again if the cache duration has passed twice since the last
-        # rotation.
-        if time_now_ms - self.time_last_rotated_ms >= self.DURATION_MS:
-            self.prev_result_cache = self.next_result_cache
-            self.next_result_cache = {}
-            self.time_last_rotated_ms = time_now_ms
-
-    def get(self, time_now_ms, key):
-        self.rotate(time_now_ms)
-        # This cache is intended to deduplicate requests, so we expect it to be
-        # missed most of the time. So we just lookup the key in all of the
-        # dictionaries rather than trying to short circuit the lookup if the
-        # key is found.
-        result = self.prev_result_cache.get(key)
-        result = self.next_result_cache.get(key, result)
-        result = self.pending_result_cache.get(key, result)
-        if result is not None:
-            return result.observe()
-        else:
-            return None
-
-    def set(self, time_now_ms, key, deferred):
-        self.rotate(time_now_ms)
-
-        result = ObservableDeferred(deferred)
-
-        self.pending_result_cache[key] = result
-
-        def shuffle_along(r):
-            # When the deferred completes we shuffle it along to the first
-            # generation of the result cache. So that it will eventually
-            # expire from the rotation of that cache.
-            self.next_result_cache[key] = result
-            self.pending_result_cache.pop(key, None)
-            return r
-
-        result.addBoth(shuffle_along)
-
-        return result.observe()
diff --git a/tests/storage/test_client_ips.py b/tests/storage/test_client_ips.py
index fc279340d4..bf674dd184 100644
--- a/tests/storage/test_client_ips.py
+++ b/tests/storage/test_client_ips.py
@@ -37,9 +37,13 @@ class ClientIpStoreTestCase(unittest.HomeserverTestCase):
         self.reactor.advance(12345678)
 
         user_id = "@user:id"
+        device_id = "MY_DEVICE"
+
+        # Insert a user IP
+        self.get_success(self.store.store_device(user_id, device_id, "display name",))
         self.get_success(
             self.store.insert_client_ip(
-                user_id, "access_token", "ip", "user_agent", "device_id"
+                user_id, "access_token", "ip", "user_agent", device_id
             )
         )
 
@@ -47,14 +51,14 @@ class ClientIpStoreTestCase(unittest.HomeserverTestCase):
         self.reactor.advance(10)
 
         result = self.get_success(
-            self.store.get_last_client_ip_by_device(user_id, "device_id")
+            self.store.get_last_client_ip_by_device(user_id, device_id)
         )
 
-        r = result[(user_id, "device_id")]
+        r = result[(user_id, device_id)]
         self.assertDictContainsSubset(
             {
                 "user_id": user_id,
-                "device_id": "device_id",
+                "device_id": device_id,
                 "ip": "ip",
                 "user_agent": "user_agent",
                 "last_seen": 12345678000,
@@ -209,14 +213,16 @@ class ClientIpStoreTestCase(unittest.HomeserverTestCase):
                 self.store.db.updates.do_next_background_update(100), by=0.1
             )
 
-        # Insert a user IP
         user_id = "@user:id"
+        device_id = "MY_DEVICE"
+
+        # Insert a user IP
+        self.get_success(self.store.store_device(user_id, device_id, "display name",))
         self.get_success(
             self.store.insert_client_ip(
-                user_id, "access_token", "ip", "user_agent", "device_id"
+                user_id, "access_token", "ip", "user_agent", device_id
             )
         )
-
         # Force persisting to disk
         self.reactor.advance(200)
 
@@ -224,7 +230,7 @@ class ClientIpStoreTestCase(unittest.HomeserverTestCase):
         self.get_success(
             self.store.db.simple_update(
                 table="devices",
-                keyvalues={"user_id": user_id, "device_id": "device_id"},
+                keyvalues={"user_id": user_id, "device_id": device_id},
                 updatevalues={"last_seen": None, "ip": None, "user_agent": None},
                 desc="test_devices_last_seen_bg_update",
             )
@@ -232,14 +238,14 @@ class ClientIpStoreTestCase(unittest.HomeserverTestCase):
 
         # We should now get nulls when querying
         result = self.get_success(
-            self.store.get_last_client_ip_by_device(user_id, "device_id")
+            self.store.get_last_client_ip_by_device(user_id, device_id)
         )
 
-        r = result[(user_id, "device_id")]
+        r = result[(user_id, device_id)]
         self.assertDictContainsSubset(
             {
                 "user_id": user_id,
-                "device_id": "device_id",
+                "device_id": device_id,
                 "ip": None,
                 "user_agent": None,
                 "last_seen": None,
@@ -272,14 +278,14 @@ class ClientIpStoreTestCase(unittest.HomeserverTestCase):
 
         # We should now get the correct result again
         result = self.get_success(
-            self.store.get_last_client_ip_by_device(user_id, "device_id")
+            self.store.get_last_client_ip_by_device(user_id, device_id)
         )
 
-        r = result[(user_id, "device_id")]
+        r = result[(user_id, device_id)]
         self.assertDictContainsSubset(
             {
                 "user_id": user_id,
-                "device_id": "device_id",
+                "device_id": device_id,
                 "ip": "ip",
                 "user_agent": "user_agent",
                 "last_seen": 0,
@@ -296,11 +302,14 @@ class ClientIpStoreTestCase(unittest.HomeserverTestCase):
                 self.store.db.updates.do_next_background_update(100), by=0.1
             )
 
-        # Insert a user IP
         user_id = "@user:id"
+        device_id = "MY_DEVICE"
+
+        # Insert a user IP
+        self.get_success(self.store.store_device(user_id, device_id, "display name",))
         self.get_success(
             self.store.insert_client_ip(
-                user_id, "access_token", "ip", "user_agent", "device_id"
+                user_id, "access_token", "ip", "user_agent", device_id
             )
         )
 
@@ -324,7 +333,7 @@ class ClientIpStoreTestCase(unittest.HomeserverTestCase):
                     "access_token": "access_token",
                     "ip": "ip",
                     "user_agent": "user_agent",
-                    "device_id": "device_id",
+                    "device_id": device_id,
                     "last_seen": 0,
                 }
             ],
@@ -347,14 +356,14 @@ class ClientIpStoreTestCase(unittest.HomeserverTestCase):
 
         # But we should still get the correct values for the device
         result = self.get_success(
-            self.store.get_last_client_ip_by_device(user_id, "device_id")
+            self.store.get_last_client_ip_by_device(user_id, device_id)
         )
 
-        r = result[(user_id, "device_id")]
+        r = result[(user_id, device_id)]
         self.assertDictContainsSubset(
             {
                 "user_id": user_id,
-                "device_id": "device_id",
+                "device_id": device_id,
                 "ip": "ip",
                 "user_agent": "user_agent",
                 "last_seen": 0,
diff --git a/tests/util/test_logcontext.py b/tests/util/test_logcontext.py
index 8b8455c8b7..281b32c4b8 100644
--- a/tests/util/test_logcontext.py
+++ b/tests/util/test_logcontext.py
@@ -179,6 +179,30 @@ class LoggingContextTestCase(unittest.TestCase):
             nested_context = nested_logging_context(suffix="bar")
             self.assertEqual(nested_context.request, "foo-bar")
 
+    @defer.inlineCallbacks
+    def test_make_deferred_yieldable_with_await(self):
+        # an async function which retuns an incomplete coroutine, but doesn't
+        # follow the synapse rules.
+
+        async def blocking_function():
+            d = defer.Deferred()
+            reactor.callLater(0, d.callback, None)
+            await d
+
+        sentinel_context = LoggingContext.current_context()
+
+        with LoggingContext() as context_one:
+            context_one.request = "one"
+
+            d1 = make_deferred_yieldable(blocking_function())
+            # make sure that the context was reset by make_deferred_yieldable
+            self.assertIs(LoggingContext.current_context(), sentinel_context)
+
+            yield d1
+
+            # now it should be restored
+            self._check_test_key("one")
+
 
 # a function which returns a deferred which has been "called", but
 # which had a function which returned another incomplete deferred on
diff --git a/tests/util/test_snapshot_cache.py b/tests/util/test_snapshot_cache.py
deleted file mode 100644
index 1a44f72425..0000000000
--- a/tests/util/test_snapshot_cache.py
+++ /dev/null
@@ -1,63 +0,0 @@
-# -*- coding: utf-8 -*-
-# Copyright 2015, 2016 OpenMarket Ltd
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-#     http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-
-
-from twisted.internet.defer import Deferred
-
-from synapse.util.caches.snapshot_cache import SnapshotCache
-
-from .. import unittest
-
-
-class SnapshotCacheTestCase(unittest.TestCase):
-    def setUp(self):
-        self.cache = SnapshotCache()
-        self.cache.DURATION_MS = 1
-
-    def test_get_set(self):
-        # Check that getting a missing key returns None
-        self.assertEquals(self.cache.get(0, "key"), None)
-
-        # Check that setting a key with a deferred returns
-        # a deferred that resolves when the initial deferred does
-        d = Deferred()
-        set_result = self.cache.set(0, "key", d)
-        self.assertIsNotNone(set_result)
-        self.assertFalse(set_result.called)
-
-        # Check that getting the key before the deferred has resolved
-        # returns a deferred that resolves when the initial deferred does.
-        get_result_at_10 = self.cache.get(10, "key")
-        self.assertIsNotNone(get_result_at_10)
-        self.assertFalse(get_result_at_10.called)
-
-        # Check that the returned deferreds resolve when the initial deferred
-        # does.
-        d.callback("v")
-        self.assertTrue(set_result.called)
-        self.assertTrue(get_result_at_10.called)
-
-        # Check that getting the key after the deferred has resolved
-        # before the cache expires returns a resolved deferred.
-        get_result_at_11 = self.cache.get(11, "key")
-        self.assertIsNotNone(get_result_at_11)
-        if isinstance(get_result_at_11, Deferred):
-            # The cache may return the actual result rather than a deferred
-            self.assertTrue(get_result_at_11.called)
-
-        # Check that getting the key after the deferred has resolved
-        # after the cache expires returns None
-        get_result_at_12 = self.cache.get(12, "key")
-        self.assertIsNone(get_result_at_12)